Slurm support for QAT Simplified Flow + Qwen3-8B recipe #285
Conversation
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

@@ Coverage Diff @@
## main #285 +/- ##
=======================================
Coverage 73.82% 73.82%
=======================================
Files 172 172
Lines 17438 17438
=======================================
+ Hits 12873 12874 +1
+ Misses 4565 4564 -1
Walkthrough

Refactors the NeMo QAT/QAD flow into a Slurm-capable multi-stage pipeline, adds Slurm and dataset utilities, an OpenScience-to-chat processor, an in-memory MMLU runner, and a NeMo-run checkpoint export helper; updates QAT docs and README entries.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant Flow as nemo_qat_flow.py
participant Exec as Executor (Local/Slurm)
participant Proc as OpenScience Processor
participant Import as BF16 Import
participant MMLU1 as MMLU Eval (BF16)
participant PTQ as PTQ Job
participant MMLU2 as MMLU Eval (PTQ)
participant Train as QAT/QAD Trainer
participant MMLU3 as MMLU Eval (SFT)
participant Export as export_most_recent_ckpt
User->>Flow: run with args (exp_dir, algorithm, --use-slurm, parallelism...)
Flow->>Exec: configure executors (CPU / PTQ-GPU / Train-GPU)
Flow->>Proc: download & process OpenScience -> produce train/val JSONL
Proc-->>Flow: training.jsonl, validation.jsonl
Flow->>Import: restore BF16 checkpoint
Import-->>Flow: bf16 checkpoint path
Flow->>MMLU1: evaluate BF16
MMLU1-->>Flow: BF16 metrics
Flow->>PTQ: run PTQ (algorithm, kv-cache options)
PTQ-->>Flow: quantized artifacts
Flow->>MMLU2: evaluate PTQ model
MMLU2-->>Flow: PTQ metrics
Flow->>Train: run QAT/QAD training (recipe, lr, devices)
Train-->>Flow: SFT checkpoints
Flow->>MMLU3: evaluate SFT model
MMLU3-->>Flow: SFT metrics
Flow->>Export: export_most_recent_ckpt(exp_dir, hf_out)
Export-->>Flow: HF export
Flow-->>User: logs, metrics, exported model
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Actionable comments posted: 8
🧹 Nitpick comments (20)
modelopt/torch/export/plugins/nemo_run.py (3)
38-47: Type + stability nit: annotate return type and prefer higher-resolution mtime. Add the return type and consider `st_mtime_ns` for better tie-breaking.
-def _get_most_recent_subdir(directory: Path): +def _get_most_recent_subdir(directory: Path) -> Path: @@ - most_recent = max(subdirs, key=lambda x: x.stat().st_mtime) + most_recent = max(subdirs, key=lambda x: x.stat().st_mtime_ns)
24-35: Log destination and resolve absolute paths for clarity. Small QoL: include output_path in logs; resolve paths to avoid ambiguity.
-def export_most_recent_ckpt(directory: str, output_path: str): +def export_most_recent_ckpt(directory: str, output_path: str): @@ - logging.info(f"Exporting most recent NeMo Run checkpoint: {most_recent_ckpt}") + most_recent_ckpt = str(Path(most_recent_ckpt).resolve()) + output_path = str(Path(output_path).resolve()) + logging.info(f"Exporting most recent NeMo Run checkpoint: {most_recent_ckpt} -> {output_path}")
50-71: Avoid importing private helpers across modules; provide a public wrapper. `examples/nemo_run/common/in_memory_mmlu.py` imports `_get_most_recent_ckpt`. Expose a public symbol to decouple callers from private helpers.

```diff
 def _get_most_recent_ckpt(directory: str):
 @@
     return str(most_recent)
+
+
+def get_most_recent_ckpt(directory: str) -> str:
+    """Public wrapper for resolving the most recent NeMo Run checkpoint directory."""
+    return _get_most_recent_ckpt(directory)
```

examples/nemo_run/common/in_memory_mmlu.py (2)
20-21: Stop importing a private symbol; switch to a public helper. Once `get_most_recent_ckpt` is added, import and use it here.

```diff
-from modelopt.torch.export.plugins.nemo_run import _get_most_recent_ckpt
+from modelopt.torch.export.plugins.nemo_run import get_most_recent_ckpt
@@
-    ckpt_path = _get_most_recent_ckpt(args.ckpt_dir)
+    ckpt_path = get_most_recent_ckpt(args.ckpt_dir)
```

Also applies to: 49-49
46-49: Simplify ckpt selection after making args mutually exclusive. No need for an assert or pre-init; pick based on which arg is set.

```diff
-    assert args.nemo_ckpt or args.ckpt_dir, "Provide one of either --nemo_ckpt or --ckpt_dir."
-    ckpt_path = args.nemo_ckpt
-    if args.ckpt_dir:
-        ckpt_path = _get_most_recent_ckpt(args.ckpt_dir)
+    ckpt_path = args.nemo_ckpt if args.nemo_ckpt else get_most_recent_ckpt(args.ckpt_dir)
```

examples/nemo_run/common/process_openscience.py (3)
17-19: Remove unused import.

```diff
-import json
 import os
 from pathlib import Path
```

58-59: Ensure parent directories exist when creating the processed dir.

```diff
-    Path(proc_dir).mkdir(exist_ok=True)
+    Path(proc_dir).mkdir(parents=True, exist_ok=True)
```

43-45: Add a split seed for reproducibility.

```diff
-    split_ds = ds["train"].train_test_split(test_size=0.1)
+    split_ds = ds["train"].train_test_split(test_size=0.1, seed=42)
```

examples/nemo_run/qat/README.md (6)
11-17: Fix ordered list numbering. Use 1..6 to silence linters and improve readability.
-1. Process Nvidia/OpenScience data (if `--data-path` is not specified) -1. Import NeMo BF16 model checkpoint and evaluate 5% of MMLU on BF16 checkpoint -1. PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint -1. SFT (finetune) the model -1. Evaluate 5% of MMLU on the SFT checkpoint -1. Export model to Unified checkpoint (HuggingFace) format in lower precision +1. Process NVIDIA/OpenScience data (if `--data-path` is not specified) +2. Import NeMo BF16 model checkpoint and evaluate 5% of MMLU on the BF16 checkpoint +3. PTQ the model and evaluate 5% of MMLU on the PTQ checkpoint +4. SFT (fine-tune) the model +5. Evaluate 5% of MMLU on the SFT checkpoint +6. Export model to Unified checkpoint (Hugging Face) format in lower precision
44-44: Grammar: duplicate “on”.

```diff
-To run the example locally, launch a [NeMo container](...) with version 25.07 or higher using Docker on on a Slurm interactive node.
+To run the example locally, launch a [NeMo container](...) with version 25.07 or higher using Docker on a Slurm interactive node.
```
50-51: Capitalize “Slurm” and tighten wording.

```diff
-To run the example on slurm, edit the `SLURM_CONFIG` ...
+To run the example on Slurm, edit the `SLURM_CONFIG` ...
```
85-88: Specify a language for the fenced block (lint): add a `text` language tag to the bare fence around the `qat_flow_ckpts` listing.

91-128: Specify a language for the directory tree (lint): add a `text` language tag to the bare fence around the `├── 00_openscience_data ...` tree.

132-132: Minor grammar cleanup.

```diff
-By default the script will use the model/tokenizer's chat template, which may not contain the `{% generation %}` and `{% endgeneration %}` tags around the assistant tokens which are needed to generate the assistant loss mask (see [this PR](https://github.com/huggingface/transformers/pull/30650)). To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
+By default, the script uses the model/tokenizer's chat template, which may not contain the `{% generation %}` and `{% endgeneration %}` tags around the assistant tokens needed to generate the assistant loss mask (see [this PR](https://github.com/huggingface/transformers/pull/30650)). To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
```
examples/nemo_run/common/utils.py (2)
52-66: Require job_dir also for LocalTunnel. `LocalTunnel` still needs a job directory; validate consistently.

```diff
-        if not self.use_local_tunnel:
+        if not self.use_local_tunnel:  # Only validate SSH tunnel settings if not using local tunnel
@@
-            if not self.job_dir:
-                raise ValueError(
-                    "SlurmConfig.job_dir must be set to directory for storing runs on cluster"
-                )
+            if not self.job_dir:
+                raise ValueError("SlurmConfig.job_dir must be set to directory for storing runs on cluster")
+        else:
+            if not self.job_dir:
+                raise ValueError("SlurmConfig.job_dir must be set when use_local_tunnel=True")
```
126-129: Specify encoding when reading templates.

```diff
-def read_chat_template(template_path: str):
-    with open(template_path) as f:
+def read_chat_template(template_path: str):
+    with open(template_path, encoding="utf-8") as f:
         return f.read().strip()
```

examples/nemo_run/qat/nemo_qat_flow.py (4)
245-245: Remove debug limiter on validation. `limit_val_batches=2` will skew metrics. Either make it a CLI flag for dev only or remove.
- train.trainer.limit_val_batches = 2 # TODO remove + # Consider exposing via CLI for quick dev runs: + # train.trainer.limit_val_batches = args.limit_val_batches
219-230: Constants defined under main are used inside main; import-time usage will crash. If this module is imported and main(args) is called, SEQUENCE_LENGTH/GBS/MBS/TRAIN_STEPS/VAL_INTERVAL won’t exist. Hoist them to module scope or make them CLI args.
Also applies to: 367-372
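A minimal sketch of the hoisting suggestion, assuming the constant names listed above; the placeholder values and the `add_train_args` helper are illustrative, not the PR's actual defaults:

```python
import argparse

# Module-scope defaults: importing this module and calling main(args) no longer
# depends on the __main__ block having defined these names first.
SEQUENCE_LENGTH = 4096  # placeholder; keep whatever nemo_qat_flow.py currently sets
GBS = 512               # placeholder
MBS = 1                 # placeholder
TRAIN_STEPS = 200
VAL_INTERVAL = 50       # placeholder


def add_train_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Optionally expose the same knobs on the CLI (hypothetical flag names).
    parser.add_argument("--train-steps", type=int, default=TRAIN_STEPS)
    parser.add_argument("--val-interval", type=int, default=VAL_INTERVAL)
    return parser
```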
29-31: Avoid sys.path manipulation for intra-repo imports. Prefer packaging examples as a module or using relative imports via an installed editable package to remove path hacks.
343-363: SLURM_CONFIG lifecycle and defaults. SLURM_CONFIG exists only under main and only when --use-slurm is passed. If someone imports and calls main(args) with use_slurm=True, this will NameError. Also, verify that time="240" matches your site policy (some clusters require HH:MM:SS) and ensure HF_TOKEN isn’t logged.
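One way to sidestep the NameError, sketched under the assumption that the config can be passed in explicitly; `SlurmConfig`, `create_slurm_executor`, and `run.LocalExecutor` are the names this flow already uses:

```python
import nemo_run as run

from utils import SlurmConfig, create_slurm_executor  # as imported by nemo_qat_flow.py


def pick_cpu_executor(args, slurm_config: SlurmConfig | None = None):
    # Sketch: take the config as an argument so nothing depends on a module-level
    # SLURM_CONFIG that only exists when __main__ ran with --use-slurm.
    if args.use_slurm:
        if slurm_config is None:
            raise ValueError("--use-slurm requires a SlurmConfig (use HH:MM:SS for time; keep HF_TOKEN out of logs)")
        return create_slurm_executor(slurm_config)
    return run.LocalExecutor()
```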
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
examples/nemo_run/common/in_memory_mmlu.py
(1 hunks)examples/nemo_run/common/process_openscience.py
(1 hunks)examples/nemo_run/common/utils.py
(1 hunks)examples/nemo_run/qat/README.md
(2 hunks)examples/nemo_run/qat/nemo_qat_flow.py
(6 hunks)modelopt/torch/export/plugins/nemo_run.py
(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/nemo_run/common/in_memory_mmlu.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
_get_most_recent_ckpt
(50-71)modelopt/torch/utils/plugins/megatron_mmlu.py (1)
megatron_mmlu
(65-152)
examples/nemo_run/common/utils.py (1)
examples/nemo_run/common/process_lima.py (1)
download_hf_dataset
(29-37)
examples/nemo_run/qat/nemo_qat_flow.py (2)
examples/nemo_run/common/utils.py (4)
SlurmConfig
(24-72)create_slurm_executor
(75-117)get_finetune_recipe
(120-123)read_chat_template
(126-128)examples/nemo_run/common/in_memory_mmlu.py (1)
parse_args
(24-41)
🪛 LanguageTool
examples/nemo_run/qat/README.md
[grammar] ~12-~12: There might be a mistake here.
Context: ...odel checkpoint and evaluate 5% of MMLU on BF16 checkpoint 1. PTQ the model and ev...
(QB_NEW_EN)
[grammar] ~13-~13: There might be a mistake here.
Context: .... PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint 1. SFT (finetune) the mo...
(QB_NEW_EN)
[grammar] ~16-~16: There might be a mistake here.
Context: ...ate 5% of MMLU on the SFT checkpoint 1. Export model to Unified checkpoint (HuggingFac...
(QB_NEW_EN)
[grammar] ~35-~35: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ## Usage ### Prereq...
(QB_NEW_EN)
[grammar] ~44-~44: There might be a mistake here.
Context: ...version 25.07 or higher using Docker on on a Slurm interactive node. Mount your cl...
(QB_NEW_EN)
[style] ~53-~53: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...
(FEEL_FREE_TO_STYLE_ME)
[style] ~62-~62: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...
(ENGLISH_WORD_REPEAT_BEGINNING_RULE)
[grammar] ~132-~132: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...
(QB_NEW_EN)
[grammar] ~132-~132: There might be a mistake here.
Context: ...--chat-template <my_template.txt>` flag.
(QB_NEW_EN)
🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/README.md
52-52: Headings should be surrounded by blank lines
Expected: 1; Actual: 0; Below
(MD022, blanks-around-headings)
85-85: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
91-91: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
examples/nemo_run/common/in_memory_mmlu.py (1)

55-56: Confirm `model.module` is always present. Depending on how the model is wrapped, `.module` may be absent on single-GPU runs. Please run a quick smoke test on both single-GPU and DDP to ensure `model.module` is valid; otherwise pass `model` directly when `.module` is missing.
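A small defensive pattern for that smoke test, assuming only that DDP-style wrappers expose `.module` while bare models do not (generic Python, not a specific NeMo API):

```python
def unwrap_model(model):
    # Returns the underlying module for DDP-wrapped models and the model itself
    # on single-GPU runs where no .module attribute exists.
    return getattr(model, "module", model)
```

Callers can then pass `unwrap_model(model)` to the MMLU helper regardless of how the checkpoint was restored.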
examples/nemo_run/qat/nemo_qat_flow.py (1)

170-178: HuggingFace offline mode caveat. With TRANSFORMERS_OFFLINE=1 in SlurmConfig defaults, importing from hf:// may fail unless the model/tokenizer are pre-cached inside the container. Confirm cache availability or override the env var.
Also applies to: 202-211
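If `TRANSFORMERS_OFFLINE=1` stays in the defaults, one workaround (a sketch, not part of this PR) is to warm the Hugging Face cache from a node with network access before submitting the jobs; the repo id assumes the Qwen3-8B checkpoint used by this recipe:

```python
from huggingface_hub import snapshot_download

# Populate the local HF cache once, so offline jobs inside the container can
# resolve hf://Qwen/Qwen3-8B without network access.
snapshot_download(repo_id="Qwen/Qwen3-8B")
```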
```python
def create_slurm_executor(
    slurm_cfg: SlurmConfig, nodes: int = 1, ntasks_per_node: int = 1, num_gpus: int = 0
):
    # Configure tunnel
    if slurm_cfg.use_local_tunnel:
        # Use LocalTunnel when already on the cluster
        tunnel = run.LocalTunnel(job_dir=slurm_cfg.job_dir)
    else:
        # Use SSH tunnel when launching from local machine
        tunnel = run.SSHTunnel(
            host=slurm_cfg.host,
            user=slurm_cfg.user,
            job_dir=slurm_cfg.job_dir,
            identity=slurm_cfg.identity,  # can be None
        )

    if num_gpus > 0:
        return run.SlurmExecutor(
            account=slurm_cfg.account,
            partition=slurm_cfg.partition_gpu,
            ntasks_per_node=ntasks_per_node,
            gpus_per_node=num_gpus,
            nodes=nodes,
            tunnel=tunnel,
            container_image=slurm_cfg.container_image,
            container_mounts=slurm_cfg.container_mounts,
            time=slurm_cfg.time,
            packager=run.GitArchivePackager(),
            mem="0",
            gres=f"gpu:{num_gpus}",
        )
    else:
```
💡 Verification agent
🧩 Analysis chain
Propagate env vars to Slurm jobs. `SlurmConfig.env_vars` is never applied; pass it to the executor (param name may differ by nemo-run version).
- return run.SlurmExecutor(
+ return run.SlurmExecutor(
account=slurm_cfg.account,
partition=slurm_cfg.partition_gpu,
ntasks_per_node=ntasks_per_node,
gpus_per_node=num_gpus,
nodes=nodes,
tunnel=tunnel,
container_image=slurm_cfg.container_image,
container_mounts=slurm_cfg.container_mounts,
time=slurm_cfg.time,
packager=run.GitArchivePackager(),
mem="0",
gres=f"gpu:{num_gpus}",
+ env_vars=slurm_cfg.env_vars, # verify parameter name: env_vars/env
)
@@
- return run.SlurmExecutor(
+ return run.SlurmExecutor(
account=slurm_cfg.account,
partition=slurm_cfg.partition_cpu,
nodes=nodes,
tunnel=tunnel,
container_image=slurm_cfg.container_image,
container_mounts=slurm_cfg.container_mounts,
time=slurm_cfg.time,
packager=run.GitArchivePackager(),
mem="0",
+ env_vars=slurm_cfg.env_vars, # verify parameter name: env_vars/env
)
If the constructor doesn’t support it, set it on the instance if available (e.g., `executor.env_vars.update(...)`).
Also applies to: 107-117
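A hedged way to apply that fallback: try the constructor keyword first and mutate the instance if the installed nemo-run version does not accept it (the kwarg name is unverified, as noted above):

```python
def executor_with_env(executor_cls, env_vars, **kwargs):
    # Prefer passing env vars at construction; fall back to the instance attribute.
    try:
        return executor_cls(env_vars=env_vars, **kwargs)
    except TypeError:
        executor = executor_cls(**kwargs)
        if isinstance(getattr(executor, "env_vars", None), dict):
            executor.env_vars.update(env_vars)
        return executor
```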
🏁 Script executed:
#!/usr/bin/env bash
set -e
# 1. Show import lines to determine how `run` is imported
sed -n '1,30p' examples/nemo_run/common/utils.py
# 2. Search repository for a local SlurmExecutor definition
rg -nP --type=py "class\s+SlurmExecutor" -n -C3 .
# 3. If not found locally, install the package and introspect its signature
pip install nemo-run
python3 - << 'PYCODE'
import inspect
try:
from nemo_run.common import run
except ImportError:
import run
sig = inspect.signature(run.SlurmExecutor)
print("SlurmExecutor signature:", sig)
PYCODE
Length of output: 1194
🏁 Script executed:
#!/usr/bin/env bash
set -e
# 1. Confirm no local SlurmExecutor in repo
rg -nP --type=py "class\s+SlurmExecutor" -n -C3 .
# 2. Install nemo-run and inspect SlurmExecutor signature
pip install nemo-run
python3 - << 'PYCODE'
import inspect
import nemo_run as run
sig = inspect.signature(run.SlurmExecutor)
print("Signature:", sig)
print("Parameters:", list(sig.parameters.keys()))
PYCODE
Length of output: 60
🏁 Script executed:
#!/usr/bin/env bash
# Re-verify SlurmExecutor existence and signature without exiting on errors.
# 1. Check for local SlurmExecutor; ignore if none found.
rg -nP --type=py "class\s+SlurmExecutor" -n -C3 . || echo "No local SlurmExecutor definition in repo."
# 2. Install nemo-run package (suppressing verbose output).
pip install nemo-run >/dev/null 2>&1
# 3. Introspect the SlurmExecutor constructor signature and doc.
python3 - << 'PYCODE'
import inspect
import nemo_run as run
sig = inspect.signature(run.SlurmExecutor)
params = sig.parameters
print("Signature:", sig)
print("Parameter names:", list(params.keys()))
# Print each parameter with default if any
for name, param in params.items():
print(f" - {name}: default={param.default!r}")
PYCODE
Length of output: 4059
🏁 Script executed:
#!/usr/bin/env bash
set -e
# Locate SlurmConfig definition
grep -R "class SlurmConfig" -n examples/nemo_run/common/utils.py
# Display SlurmConfig fields
sed -n '1,200p' examples/nemo_run/common/utils.py | sed -n '/class SlurmConfig/,/^$/p'
# Search for env_vars usage in the file
grep -R "env_vars" -n examples/nemo_run/common/utils.py
Length of output: 427
Pass `env_vars` from `SlurmConfig` to `SlurmExecutor`: add `env_vars=slurm_cfg.env_vars` to both `run.SlurmExecutor(…)` calls in the GPU and CPU branches of `create_slurm_executor` (examples/nemo_run/common/utils.py) so that user-configured environment variables are applied.
🤖 Prompt for AI Agents
In examples/nemo_run/common/utils.py around lines 75 to 106, the SlurmExecutor
calls do not pass the user-configured environment variables from SlurmConfig;
add env_vars=slurm_cfg.env_vars to both the GPU branch and the CPU (else) branch
SlurmExecutor(...) invocations so the executor is launched with the configured
environment variables.
```python
def get_finetune_recipe(recipe_name: str):
    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
```
Fix AttributeError risk when recipe module is missing. `getattr(llm, recipe_name)` will raise if the module doesn’t exist; guard it first.
-def get_finetune_recipe(recipe_name: str):
- if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
- raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
- return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
+def get_finetune_recipe(recipe_name: str):
+ recipe_mod = getattr(llm, recipe_name, None)
+ if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"):
+ raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe")
+ return recipe_mod.finetune_recipe(peft_scheme=None)
🤖 Prompt for AI Agents
In examples/nemo_run/common/utils.py around lines 120 to 124, the code calls
getattr(llm, recipe_name) directly which will raise AttributeError if the recipe
module is missing; first check that the recipe exists (e.g., use hasattr(llm,
recipe_name) or getattr(llm, recipe_name, None) and verify it’s not None), then
ensure the found object has a finetune_recipe attribute before accessing it; if
the recipe is missing or lacks finetune_recipe, raise a clear ValueError,
otherwise call finetune_recipe(peft_scheme=None).
Actionable comments posted: 4
🧹 Nitpick comments (9)
examples/llm_qat/README.md (1)
14-14: Add Docs link for consistency (optional). Consider filling the “Docs” column with a pointer to the new NeMo QAT/QAD guide (e.g., ../nemo_run/qat/README.md) to match other rows that include docs.
examples/nemo_run/qat/ADVANCED.md (3)
11-11: Document how to install rsync inside the container. Provide an explicit install step or a container tag that includes rsync.
Add (inside the container, Debian/Ubuntu base): `apt-get update && apt-get install -y rsync openssh-client`. Or note the minimal container version that bundles rsync.
13-16: Specify fenced code language. Add a language to satisfy MD040 and improve rendering (use a `text` tag on the bare fence around the `qat_flow_ckpts` listing).

19-56: Specify fenced code language for the directory tree. Add a language to satisfy MD040 and improve rendering (use a `text` tag on the bare fence around the `├── 00_openscience_data ...` tree).

examples/nemo_run/qat/README.md (5)

5-8: Header links: remove duplicate or rename. Both “Slurm Examples” and “Advanced Topics” point to ADVANCED.md. Either remove one or rename to distinct anchors (the suggested change keeps only the “Slurm Examples” link).
Or create section anchors in ADVANCED.md and link each separately.
44-44: Remove the double space in “You can run the example either locally or on a [Slurm cluster](ADVANCED.md).”
46-55: Safer container mounting; avoid hardcoding the dist-packages path. Mount repos and use editable installs or PYTHONPATH to avoid Python-version-specific paths.
-Example docker command: -``` -docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash -``` +Example docker command: +```bash +docker run --gpus all --shm-size 20g --rm -it \ + -v /home/user/NeMo:/opt/NeMo \ + -v /home/user/TensorRT-Model-Optimizer:/workspace/TRTMO \ + -v /home/user:/home/user \ + nvcr.io/nvidia/nemo:25.07 bash +``` +Then inside the container: +```bash +pip install -e /workspace/TRTMO # or: export PYTHONPATH=/workspace/TRTMO:$PYTHONPATH +```
63-69: Clarify working directory. Specify that the command is run from `examples/nemo_run`.
-From the `nemo_run` folder, launch the example with the `qat/nemo_qat_flow.py` script. +From the `examples/nemo_run` folder, launch the example with the `qat/nemo_qat_flow.py` script.
96-96: Grammar: add article.

-To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
+To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
📜 Review details
📒 Files selected for processing (3)
examples/llm_qat/README.md
(1 hunks)examples/nemo_run/qat/ADVANCED.md
(1 hunks)examples/nemo_run/qat/README.md
(2 hunks)
🧰 Additional context used
🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/ADVANCED.md
13-13: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
19-19: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🪛 LanguageTool
examples/nemo_run/qat/README.md
[grammar] ~5-~5: There might be a mistake here.
Context: ...Example Slurm Examples | Advanced Topics | [NeMo I...
(QB_NEW_EN)
[grammar] ~6-~6: There might be a mistake here.
Context: ...D.md) | Advanced Topics | [NeMo Integration](https://github.com/NV...
(QB_NEW_EN)
[grammar] ~22-~22: There might be a mistake here.
Context: ...odel checkpoint and evaluate 5% of MMLU on BF16 checkpoint 1. PTQ the model and ev...
(QB_NEW_EN)
[grammar] ~23-~23: There might be a mistake here.
Context: .... PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint 1. SFT (finetune) the mo...
(QB_NEW_EN)
[grammar] ~26-~26: There might be a mistake here.
Context: ...ate 5% of MMLU on the SFT checkpoint 1. Export model to Unified checkpoint (HuggingFac...
(QB_NEW_EN)
[style] ~64-~64: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...
(ENGLISH_WORD_REPEAT_BEGINNING_RULE)
[grammar] ~90-~90: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ### Custom Chat Tem...
(QB_NEW_EN)
[grammar] ~96-~96: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...
(QB_NEW_EN)
[style] ~99-~99: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...
(FEEL_FREE_TO_STYLE_ME)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
examples/nemo_run/qat/README.md (2)
28-37
: Mermaid flow looks good.
48-50
: Pin NeMo commit with full hash and date. In examples/nemo_run/qat/README.md (lines 48–50), replace the short hash `ddcb75f` with its full 40-character commit hash and include the commit date/message. You can retrieve these details by cloning the NeMo repo locally and running:

git clone https://github.com/NVIDIA-NeMo/NeMo.git
cd NeMo
git rev-parse ddcb75f
git show -s --format='%H %ad %s' ddcb75f
Actionable comments posted: 0
♻️ Duplicate comments (2)
examples/nemo_run/qat/README.md (2)
19-27: Align the stage list with the 8-stage flow and exact task names. The list shows six generic steps while the code/diagram uses eight named stages. Please list all eight with exact IDs to match logs and ADVANCED.md.
Use:
-Currently the Simplified Flow runs the following steps in order: - -1. Process Nvidia/OpenScience data (if `--data-path` is not specified) -1. Import NeMo BF16 model checkpoint and evaluate 5% of MMLU on BF16 checkpoint -1. PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint -1. SFT (finetune) the model -1. Evaluate 5% of MMLU on the SFT checkpoint -1. Export model to Unified checkpoint (HuggingFace) format in lower precision +Flow stages: + +1. 00_openscience_data — Process NVIDIA OpenScience data (skipped if `--data-path` is provided) +2. 01_import_model — Import NeMo BF16 model checkpoint +3. 02_mmlu_bf16 — Evaluate 5% MMLU on BF16 checkpoint +4. 03_ptq — Apply PTQ +5. 04_mmlu_ptq — Evaluate 5% MMLU on PTQ checkpoint +6. 05_train — SFT/QAT (and optional QAD) +7. 06_mmlu_sft — Evaluate 5% MMLU on SFT/QAT checkpoint +8. 07_export_hf — Export to Hugging Face (Unified) format
15-16: Fix incomplete sentence after PTQ. This reads as a fragment; complete it to explain why QAT/QAD follow PTQ.
Apply:
-After PTQ (post-training quantization), the quantized model may +After PTQ (post-training quantization), the quantized model may exhibit accuracy degradation on tasks like MMLU; the subsequent QAT/QAD stages aim to recover that loss.
🧹 Nitpick comments (7)
examples/nemo_run/qat/README.md (7)
5-7: Avoid duplicate links to the same target in the header. Both “Slurm Examples” and “Advanced Topics” point to ADVANCED.md. Either consolidate or point “Slurm Examples” to a section anchor.
If ADVANCED.md has a Slurm section, consider:
-[Slurm Examples](ADVANCED.md) | -[Advanced Topics](ADVANCED.md) | +[Slurm Examples](ADVANCED.md#slurm) | +[Advanced Topics](ADVANCED.md) |
43-43
: Fix minor spacing typo.-You can run the example either locally or on a [Slurm cluster](ADVANCED.md). +You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
45-49: Prefer PYTHONPATH over bind-mounting into site-packages. Mounting into /usr/local/.../site-packages can be brittle across images/python versions. Using PYTHONPATH keeps the container clean and reduces surprises.
-Example docker command: +Example docker command: @@ -``` -docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash -``` +```bash +docker run --gpus all --shm-size 20g --rm -it \ + -v /home/user/NeMo:/opt/NeMo \ + -v /home/user/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \ + -e PYTHONPATH=/opt/NeMo:/workspace/TensorRT-Model-Optimizer \ + nvcr.io/nvidia/nemo:25.07 bash +```
52-55: Add language to the fenced code block (markdownlint MD040): tag the docker command block as `bash`.

62-62: Minor wording polish and consistent naming. Use “Hugging Face” (e.g., “use the model's Hugging Face name”) and ensure dataset name capitalization is consistent.
Also, earlier “Nvidia/OpenScience” → “NVIDIA OpenScience” (addressed in the stage list fix).
91-94
: Tighten wording and fix article usage.-By default the script will use the model/tokenizer's chat template, which may not contain the `{% generation %}` and `{% endgeneration %}` tags around the assistant tokens which are needed to generate the assistant loss mask (see [this PR](https://github.com/huggingface/transformers/pull/30650)). To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag. +By default, the script uses the model/tokenizer's chat template, which may not contain the `{% generation %}` and `{% endgeneration %}` tags around assistant tokens that are needed to generate the assistant loss mask (see [this PR](https://github.com/huggingface/transformers/pull/30650)). To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
97-98: Tone/style refinement. Optional rephrase to avoid “Feel free to” and tighten the message.
-The current QAT recipe has been tuned for the Qwen3-8B model to improve accuracy on the MMLU benchmark after PTQ degradation. QAT/QAD results are highly dependent on the specific model, dataset, and hyperparameters. There is no guarantee that the same dataset will recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinations and test which combination works best. +The current QAT recipe is tuned for Qwen3-8B to improve MMLU after PTQ-induced degradation. QAT/QAD results depend on the model, dataset, and hyperparameters, and the same dataset may not recover PTQ accuracy. You are encouraged to try different model–dataset combinations and validate which works best in your environment.
📜 Review details
📒 Files selected for processing (2)
examples/nemo_run/common/process_openscience.py
(1 hunks)examples/nemo_run/qat/README.md
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/nemo_run/common/process_openscience.py
🧰 Additional context used
🪛 LanguageTool
examples/nemo_run/qat/README.md
[grammar] ~5-~5: There might be a mistake here.
Context: ...Example Slurm Examples | Advanced Topics | [NeMo I...
(QB_NEW_EN)
[grammar] ~6-~6: There might be a mistake here.
Context: ...D.md) | Advanced Topics | [NeMo Integration](https://github.com/NV...
(QB_NEW_EN)
[grammar] ~22-~22: There might be a mistake here.
Context: ...odel checkpoint and evaluate 5% of MMLU on BF16 checkpoint 1. PTQ the model and ev...
(QB_NEW_EN)
[grammar] ~23-~23: There might be a mistake here.
Context: .... PTQ the model and evaluate 5% of MMLU on PTQ Checkpoint 1. SFT (finetune) the mo...
(QB_NEW_EN)
[grammar] ~26-~26: There might be a mistake here.
Context: ...ate 5% of MMLU on the SFT checkpoint 1. Export model to Unified checkpoint (HuggingFac...
(QB_NEW_EN)
[style] ~63-~63: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...
(ENGLISH_WORD_REPEAT_BEGINNING_RULE)
[grammar] ~88-~88: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ### Custom Chat Temp...
(QB_NEW_EN)
[grammar] ~93-~93: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...
(QB_NEW_EN)
[style] ~97-~97: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...
(FEEL_FREE_TO_STYLE_ME)
🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/README.md
52-52: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (4)
examples/nemo_run/qat/README.md (4)
66-68
: LGTM: local QAT invocation example is clear and minimal.
79-80
: LGTM: QAD invocation example is clear and mirrors QAT usage.
28-37
: No changes needed: diagram labels align with code stage IDs.
All Mermaid nodes in examples/nemo_run/qat/README.md correspond exactly to the stage names in nemo_qat_flow.py.
84-87
: Resolved: CLI flags and defaults match documentation
Flags--train-nodes
(default 1),--train-gpus
(default 8), and--ptq-gpus
(default 4) are correctly defined in nemo_qat_flow.py and align with the README.
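For reference, a minimal argparse sketch matching those documented defaults (the real definitions live in `nemo_qat_flow.py`; the help strings here are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train-nodes", type=int, default=1, help="Nodes used for QAT/QAD training")
parser.add_argument("--train-gpus", type=int, default=8, help="GPUs per node for training")
parser.add_argument("--ptq-gpus", type=int, default=4, help="GPUs used for the PTQ stage")
```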
Actionable comments posted: 1
♻️ Duplicate comments (4)
examples/nemo_run/qat/ADVANCED.md (1)

5-5: Don’t instruct users to hardcode credentials; point to env-driven config. Replace “edit SLURM_CONFIG in source” with guidance to set environment variables that `SlurmConfig` reads, and reference the `utils.SlurmConfig` schema. This avoids secrets in source and eases CI.

```diff
-To run the example on slurm, edit the `SLURM_CONFIG` at the bottom of `nemo_qat_flow.py` with the appropriate credentials, container, cluster name (host), and container mounts.
+Define Slurm settings via environment variables consumed by `SlurmConfig` (see `examples/nemo_run/common/utils.py`) rather than editing source. Set values like `SLURM_ACCOUNT`, `SLURM_PARTITION_{GPU,CPU}`, `SLURM_TIME`, `CONTAINER_IMAGE`, `CONTAINER_MOUNTS`, `SLURM_HOST`, and `SLURM_USER` before launching.
```
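A possible shape for that env-driven setup, sketched on the assumption that `SlurmConfig` keeps its current fields; the variable names follow the suggestion above, and any required field not listed there (e.g. `job_dir`) would need its own variable:

```python
import os

from utils import SlurmConfig  # examples/nemo_run/common/utils.py


def slurm_config_from_env() -> SlurmConfig:
    # Read cluster settings from the environment instead of editing source.
    return SlurmConfig(
        account=os.environ["SLURM_ACCOUNT"],
        partition_gpu=os.environ["SLURM_PARTITION_GPU"],
        partition_cpu=os.environ["SLURM_PARTITION_CPU"],
        time=os.environ.get("SLURM_TIME", "04:00:00"),
        container_image=os.environ["CONTAINER_IMAGE"],
        container_mounts=os.environ.get("CONTAINER_MOUNTS", "").split(","),
        host=os.environ.get("SLURM_HOST", ""),
        user=os.environ.get("SLURM_USER", ""),
    )
```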
examples/nemo_run/qat/nemo_qat_flow.py (3)

146-153: Separate “recipe name” from “model name” and validate before getattr. Avoid overloading
model_name
for both concepts; it risks attribute errors and wrong path prefixes.- model_name = args.finetune_recipe - model_module = getattr(llm, model_name) - if not model_name: - model_name = os.path.basename(args.model_name) + recipe_name = args.finetune_recipe + if not recipe_name: + raise ValueError("--finetune-recipe must be specified when --distill is not used") + model_module = getattr(llm, recipe_name) + # Use recipe name as the filesystem prefix for artifacts + model_name = recipe_name
138-142
: KV-cache flag is forced to “disabled” by default and CLI uses underscore not hyphen.Use a tri-state CLI: only pass a flag when explicitly set, and expose hyphenated flags to match README. This avoids overriding
ptq.py
defaults unintentionally and removes doc/code mismatch.- parser.add_argument( - "--enable_kv_cache", - help="Enables KV-cache quantization", - action="store_true", - default=False, - ) + # Tri-state KV-cache control; pass a flag only if explicitly set + kv = parser.add_mutually_exclusive_group(required=False) + parser.set_defaults(enable_kv_cache=None) + kv.add_argument("--enable-kv-cache", dest="enable_kv_cache", action="store_true", help="Enable KV-cache quantization") + kv.add_argument("--disable-kv-cache", dest="enable_kv_cache", action="store_false", help="Disable KV-cache quantization") @@ - ptq = run.Script( + # Build KV-cache flag only when explicitly set + kv_cache_flag = ( + ["--enable_kv_cache"] if args.enable_kv_cache is True + else (["--disable_kv_cache"] if args.enable_kv_cache is False else []) + ) + ptq = run.Script( @@ - "--kv_cache_qformat", - args.kv_cache_qformat, - "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache", + "--kv_cache_qformat", + args.kv_cache_qformat, + *kv_cache_flag,Also applies to: 176-194
266-275
: Don’t reuse a single Slurm executor for PTQ, Train, and Export.Reusing and mutating one executor risks wrong gres/gpu allocations across stages. Create dedicated executors per stage.
- if args.use_slurm: - cpu_executor = create_slurm_executor(SLURM_CONFIG) - gpu_executor = create_slurm_executor( - SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus - ) - single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1) + if args.use_slurm: + cpu_executor = create_slurm_executor(SLURM_CONFIG) + ptq_gpu_executor = create_slurm_executor( + SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus + ) + train_gpu_executor = create_slurm_executor( + SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus + ) + single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1) else: - cpu_executor = single_gpu_executor = run.LocalExecutor() - gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus) + cpu_executor = single_gpu_executor = run.LocalExecutor() + ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus) + train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus) @@ - s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1]) + s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1]) @@ - if args.use_slurm: # Set training arguments - gpu_executor.nodes = args.train_nodes - gpu_executor.devices = gpu_executor.ntasks_per_node = args.train_gpus - else: - gpu_executor.ntasks_per_node = args.train_gpus + # training resources already encoded in train_gpu_executor @@ - s4 = exp.add( - train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep - ) + s4 = exp.add(train, tail_logs=True, name="05_train", executor=train_gpu_executor, dependencies=train_dep) @@ - gpu_executor.ntasks_per_node = 1 # will throw error if more than 1 task during export - exp.add( - export, - tail_logs=True, - name="07_export_hf", - executor=gpu_executor, - dependencies=[s5], - ) + exp.add(export, tail_logs=True, name="07_export_hf", executor=single_gpu_executor, dependencies=[s5])Also applies to: 293-301, 302-308, 311-314, 322-329
🧹 Nitpick comments (13)
examples/nemo_run/qat/ADVANCED.md (4)
3-3
: Polish the intro sentence (comma + “one”).Add a comma after “for example” and prefer “one” over “1” for docs consistency.
-If you need to run QAT/QAD on a Slurm cluster (for example to use more than 1 node), this guide covers how to configure and launch on Slurm. +If you need to run QAT/QAD on a Slurm cluster (for example, to use more than one node), this guide covers how to configure and launch on Slurm.
11-11
: Clarify how to get rsync inside the container.Add a one-liner to install rsync if missing to prevent launch failures.
-**NOTE:** `rsync` may not currently be available in the NeMo container and will be added as a dependency. +**NOTE:** If `rsync` is not available in the NeMo container, install it before launching: +```bash +apt-get update && apt-get install -y rsync +```
13-16
: Add a language to fenced code block (markdownlint MD040).-``` +```text qat_flow_ckpts qat_flow_ckpts_1755708286--- `19-56`: **Add a language to fenced code block (markdownlint MD040).** ```diff -``` +```text ├── 00_openscience_data │ ├── code │ ├── configs …
</blockquote></details> <details> <summary>examples/nemo_run/common/in_memory_mmlu.py (1)</summary><blockquote> `25-27`: **Align CLI help with actual flag name.** Help text says “--ckpt_dir” but the flag is “--finetuned_ckpt_dir”. Update description to avoid confusion. ```diff - description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --ckpt_dir" + description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --finetuned_ckpt_dir"
examples/nemo_run/qat/nemo_qat_flow.py (3)
29-31
: Avoid sys.path mutation; import via package.Prefer making
examples/nemo_run/common
a package and importingutils
normally.-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "common"))) -from utils import SlurmConfig, create_slurm_executor, get_finetune_recipe, read_chat_template +from examples.nemo_run.common.utils import ( + SlurmConfig, + create_slurm_executor, + get_finetune_recipe, + read_chat_template, +)
239-239
: Temporary validation cap left on.
train.trainer.limit_val_batches = 2
looks like a debug setting. Remove or guard with a flag before merging.
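If the cap is worth keeping for quick dev runs, a guarded sketch (the `--limit-val-batches` flag is hypothetical):

```python
import argparse


def add_debug_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument("--limit-val-batches", type=int, default=None,
                        help="Debug only: cap validation batches; unset runs full validation")
    return parser


def apply_debug_limits(train, args) -> None:
    # `train` is the finetune recipe; only touch the trainer when explicitly asked.
    if args.limit_val_batches is not None:
        train.trainer.limit_val_batches = args.limit_val_batches
```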
361-366
: Module constants are defined after use; move to top-level for clarity.They’re set before main() runs, but relocating them near imports improves readability and reduces ordering hazards.
examples/nemo_run/qat/README.md (5)
45-45
: Fix double space and minor phrasing.-You can run the example either locally or on a [Slurm cluster](ADVANCED.md). +You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
54-56
: Add language to the Docker command block (markdownlint MD040).-``` +```bash docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash--- `64-70`: **Consider a more robust local install flow for modelopt.** Mounting into `dist-packages` is brittle. Suggest editable installs. ```diff -Example docker command: +Example docker command (then install repos inside the container): @@ -```bash -docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash -``` +```bash +docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash +pip install -e /opt/NeMo +pip install -e /home/user/TensorRT-Model-Optimizer/modelopt +```
72-72
: CLI flag mismatch with code (--enable-kv-cache
vs--enable_kv_cache
).Docs use hyphenated flag; current code defines underscore. Recommend aligning to hyphenated form per argparse conventions, or update docs once code changes.
95-95
: Minor grammar.-To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag. +To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
📜 Review details
📒 Files selected for processing (5)
examples/nemo_run/common/in_memory_mmlu.py
(1 hunks)examples/nemo_run/qat/ADVANCED.md
(1 hunks)examples/nemo_run/qat/README.md
(2 hunks)examples/nemo_run/qat/nemo_qat_flow.py
(6 hunks)modelopt/torch/export/plugins/nemo_run.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/export/plugins/nemo_run.py
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/common/in_memory_mmlu.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
_get_most_recent_ckpt
(50-73)modelopt/torch/utils/plugins/megatron_mmlu.py (1)
megatron_mmlu
(65-152)
examples/nemo_run/qat/nemo_qat_flow.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
export_most_recent_ckpt
(24-35)examples/nemo_run/common/utils.py (4)
SlurmConfig
(24-72)create_slurm_executor
(75-117)get_finetune_recipe
(120-123)read_chat_template
(126-128)
🪛 LanguageTool
examples/nemo_run/qat/README.md
[grammar] ~5-~5: There might be a mistake here.
Context: ...Example Slurm Examples | Advanced Topics | [NeMo I...
(QB_NEW_EN)
[grammar] ~6-~6: There might be a mistake here.
Context: ...D.md) | Advanced Topics | [NeMo Integration](https://github.com/NV...
(QB_NEW_EN)
[style] ~65-~65: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...
(ENGLISH_WORD_REPEAT_BEGINNING_RULE)
[grammar] ~90-~90: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ### Custom Chat Temp...
(QB_NEW_EN)
[grammar] ~95-~95: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...
(QB_NEW_EN)
[style] ~99-~99: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...
(FEEL_FREE_TO_STYLE_ME)
🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/README.md
54-54: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
examples/nemo_run/qat/ADVANCED.md
13-13: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
19-19: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
examples/nemo_run/common/in_memory_mmlu.py (1)
28-35
: Good switch to a mutually exclusive group for checkpoint inputs.

examples/nemo_run/qat/nemo_qat_flow.py (1)
336-357
: I’ve initiated a shell check to verify the `SlurmConfig` import and whether `os` is already imported in `nemo_qat_flow.py`.
train.trainer.max_steps = 200
train.log.log_dir = exp_dir
train.trainer.val_check_interval = VAL_INTERVAL
train.trainer.max_steps = TRAIN_STEPS
train.trainer.max_steps = TRAIN_STEPS
train.trainer.strategy.tensor_model_parallel_size = args.train_gpus
right now it's running DP by default which is what I want (since 8B model will fit in 1 GPU)
I guess could be helpful if someone has less gpu memory per device, but not 100% necessary
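If that optional knob is ever wanted, a sketch (the `--tensor-parallel` flag is hypothetical; default 1 preserves today's pure data-parallel behavior for the 8B model):

```python
import argparse


def add_parallelism_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument("--tensor-parallel", type=int, default=1,
                        help="Set >1 only when a single GPU cannot hold the model")
    return parser


def apply_parallelism(train, args) -> None:
    # Leave the recipe's data-parallel default alone unless TP is requested.
    if args.tensor_parallel > 1:
        train.trainer.strategy.tensor_model_parallel_size = args.tensor_parallel
```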
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/nemo_run/qat/README.md (1)
74-75
: Wrong flag spelling: hyphen vs underscore.The code defines
--enable_kv_cache
and--kv-cache-qformat
; the README shows--enable-kv-cache
. This will 100% fail at runtime.Apply:
-> **_NOTE:_** To enable KV cache quantization, add `--enable-kv-cache` and specify qformat using `--kv-cache-qformat <fp8, nvfp4>`. +> **_NOTE:_** To enable KV cache quantization, add `--enable_kv_cache` and specify qformat using `--kv-cache-qformat <fp8, nvfp4>`.
♻️ Duplicate comments (3)
examples/nemo_run/common/in_memory_mmlu.py (1)
48-49
: Bug: unreachable finetuned path due to arg name mismatch.Branch checks args.ckpt_dir which doesn’t exist; it should use args.finetuned_ckpt_dir.
Apply:
- if args.ckpt_dir: - ckpt_path = _get_most_recent_ckpt(args.ckpt_dir) + if args.finetuned_ckpt_dir: + ckpt_path = _get_most_recent_ckpt(args.finetuned_ckpt_dir)examples/nemo_run/qat/nemo_qat_flow.py (2)
137-142
: Tri‑state KV‑cache flag handling; don’t force disable by default.Currently
--disable_kv_cache
is always passed when the user doesn’t specify anything, changing defaults implicitly.Apply:
@@ def get_args(): - parser.add_argument( - "--enable_kv_cache", - help="Enables KV-cache quantization", - action="store_true", - default=False, - ) + kv_group = parser.add_mutually_exclusive_group(required=False) + kv_group.add_argument( + "--enable_kv_cache", + dest="enable_kv_cache", + help="Enable KV-cache quantization", + action="store_true", + ) + kv_group.add_argument( + "--disable_kv_cache", + dest="enable_kv_cache", + help="Disable KV-cache quantization", + action="store_false", + ) + parser.set_defaults(enable_kv_cache=None) # tri-state: True/False/None @@ - ptq = run.Script( + # Build KV-cache flag only if explicitly set + kv_flag = ( + ["--enable_kv_cache"] if args.enable_kv_cache is True + else (["--disable_kv_cache"] if args.enable_kv_cache is False else []) + ) + ptq = run.Script( "/opt/NeMo/scripts/llm/ptq.py", args=[ @@ - "--kv_cache_qformat", - args.kv_cache_qformat, - "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache", + "--kv_cache_qformat", + args.kv_cache_qformat, + *kv_flag,Also applies to: 183-200
274-283
: Do not reuse one GPU Slurm/Local executor for PTQ, Train, and Export.Mutating
gpu_executor
(nodes/ntasks_per_node/devices) is error‑prone; requests can leak across stages.Apply:
@@ - if args.use_slurm: - cpu_executor = create_slurm_executor(SLURM_CONFIG) - gpu_executor = create_slurm_executor( - SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus - ) - single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1) + if args.use_slurm: + cpu_executor = create_slurm_executor(SLURM_CONFIG) + ptq_gpu_executor = create_slurm_executor( + SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus + ) + train_gpu_executor = create_slurm_executor( + SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus + ) + single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1) else: - cpu_executor = single_gpu_executor = run.LocalExecutor() - gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus) + cpu_executor = single_gpu_executor = run.LocalExecutor() + ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus) + train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus) @@ - s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1]) + s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1]) @@ - if args.use_slurm: # Set training arguments - gpu_executor.nodes = args.train_nodes - gpu_executor.devices = gpu_executor.ntasks_per_node = args.train_gpus - else: - gpu_executor.ntasks_per_node = args.train_gpus train_dep = [s3] if not args.data_path: train_dep.append(s0) - s4 = exp.add( - train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep - ) + s4 = exp.add(train, tail_logs=True, name="05_train", executor=train_gpu_executor, dependencies=train_dep) @@ - gpu_executor.ntasks_per_node = 1 # will throw error if more than 1 task during export - exp.add( - export, - tail_logs=True, - name="07_export_hf", - executor=gpu_executor, - dependencies=[s5], - ) + exp.add(export, tail_logs=True, name="07_export_hf", executor=single_gpu_executor, dependencies=[s5])Also applies to: 301-336
🧹 Nitpick comments (10)
examples/nemo_run/common/in_memory_mmlu.py (2)
26-26
: Fix help text to match the actual flag name. Says “--ckpt_dir” but the CLI exposes “--finetuned_ckpt_dir”. Update description to avoid confusion.
Apply:
- description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --ckpt_dir" + description="Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt or --finetuned_ckpt_dir"
47-55
: Optional: validate resolved checkpoint path early. A quick exists-check gives clearer error messages before restoration.
Example:
 ckpt_path = args.nemo_ckpt
 if args.finetuned_ckpt_dir:
     ckpt_path = _get_most_recent_ckpt(args.finetuned_ckpt_dir)
+    if not ckpt_path or not os.path.exists(ckpt_path):
+        raise FileNotFoundError(f"Checkpoint path not found: {ckpt_path}")
(Remember to `import os`.)
examples/nemo_run/qat/README.md (4)
45-45
: Minor: stray double space.-You can run the example either locally or on a [Slurm cluster](ADVANCED.md). +You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
97-97
: Grammar nit: add article.- To provide path to a custom chat template, use the `--chat-template <my_template.txt>` flag. + To provide a path to a custom chat template, use the `--chat-template <my_template.txt>` flag.
55-56
: Fragile mount path to site-packages. Prefer editable install or compute path. Hardcoding `/usr/local/lib/python3.12/dist-packages/modelopt` is container/Python-version dependent. Consider replacing with one of:
- Use editable install inside container: `pip install -e /home/user/TensorRT-Model-Optimizer/modelopt`.
- Or compute site-packages path dynamically:
python - <<'PY'
import site, sys, os
print(next(p for p in site.getsitepackages() if 'site-packages' in p))
PY
Then mount to that returned path.
Would you like me to propose a concise Docker snippet reflecting this?
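For reference, a minimal Python variant of that dynamic lookup (an illustration only; it assumes it is run inside the target container and relies on the standard `sysconfig` key for the package directory):

```python
# Print the active interpreter's site-packages directory; use the output as the
# bind-mount target instead of hardcoding /usr/local/lib/python3.12/dist-packages.
import sysconfig

print(sysconfig.get_paths()["purelib"])
```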
5-7
: Duplicate links to the same target.“Slurm Examples” and “Advanced Topics” both point to ADVANCED.md. If intentional, consider distinct anchors; otherwise dedupe.
examples/nemo_run/qat/nemo_qat_flow.py (4)
146-153
: Separate “recipe name” from “model name/path prefix” to avoid overload. Using `model_name = args.finetune_recipe` mixes two concepts and complicates fallbacks. Apply:
def main(args): if not args.distill and not args.finetune_recipe: raise ValueError("If distillation is not used, --finetune-recipe must be specified") - model_name = args.finetune_recipe - model_module = getattr(llm, model_name) - if not model_name: - model_name = os.path.basename(args.model_name) + recipe_name = args.finetune_recipe + model_module = getattr(llm, recipe_name) + # Use recipe name as path prefix; keep HF model name separate + path_prefix = recipe_name @@ - bf16_ckpt_path = f"{exp_dir}/{model_name}-nemo" + bf16_ckpt_path = f"{exp_dir}/{path_prefix}-nemo" @@ - ptq_model_out = f"{exp_dir}/{model_name}-{args.algorithm}" + ptq_model_out = f"{exp_dir}/{path_prefix}-{args.algorithm}"Also update the export path at Line 249 accordingly (see next comment).
Also applies to: 171-182
248-250
: Keep export path consistent with path prefix refactor.- export = run.Partial( - export_most_recent_ckpt, train.log.log_dir, output_path=f"{exp_dir}/{model_name}_hf" - ) + export = run.Partial( + export_most_recent_ckpt, train.log.log_dir, output_path=f"{exp_dir}/{path_prefix}_hf" + )
351-352
: Slurm time format should be HH:MM:SS. `"240"` is ambiguous and contradicts SlurmConfig’s guidance.
- time="240",
+ time="04:00:00",
338-338
: Detach only when using Slurm; keep local runs blocking. Improves local UX and log visibility.
- exp.run(detach=True)
+ exp.run(detach=args.use_slurm)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
examples/nemo_run/qat/README.md (2 hunks)
examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
export_most_recent_ckpt
(24-35)examples/nemo_run/common/utils.py (4)
SlurmConfig
(24-72)create_slurm_executor
(75-117)get_finetune_recipe
(120-123)read_chat_template
(126-128)examples/nemo_run/common/in_memory_mmlu.py (1)
parse_args
(24-42)
examples/nemo_run/common/in_memory_mmlu.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
_get_most_recent_ckpt
(50-73)modelopt/torch/utils/plugins/megatron_mmlu.py (1)
megatron_mmlu
(65-152)
🪛 LanguageTool
examples/nemo_run/qat/README.md
[grammar] ~5-~5: There might be a mistake here.
Context: ...Example Slurm Examples | Advanced Topics | [NeMo I...
(QB_NEW_EN)
[grammar] ~6-~6: There might be a mistake here.
Context: ...D.md) | Advanced Topics | [NeMo Integration](https://github.com/NV...
(QB_NEW_EN)
[style] ~67-~67: Three successive sentences begin with the same word. Consider rewording the sentence or use a thesaurus to find a synonym.
Context: ...dia/OpenScience) dataset will be used. To perform QAT, run: ```bash python qat/n...
(ENGLISH_WORD_REPEAT_BEGINNING_RULE)
[grammar] ~92-~92: There might be a mistake here.
Context: ... following model: - Model: Qwen3-8B - Recipe: qwen3_8b ### Custom Chat Temp...
(QB_NEW_EN)
[grammar] ~97-~97: There might be a mistake here.
Context: ...ggingface/transformers/pull/30650)). To provide path to a custom chat template, use the...
(QB_NEW_EN)
[style] ~101-~101: Consider using a less common alternative to make your writing sound more unique and professional.
Context: ... recover the accuracy of the PTQ model. Feel free to try your own model and dataset combinat...
(FEEL_FREE_TO_STYLE_ME)
🔇 Additional comments (1)
examples/nemo_run/common/in_memory_mmlu.py (1)
28-35
: Good: proper CLI exclusivity for checkpoint inputs. Using argparse’s mutually exclusive group is the right fix; this prevents ambiguous invocations.
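For readers who want the shape of that pattern, a minimal sketch (flag names follow this example's CLI; the help strings and `required=True` are assumptions):

```python
import argparse

parser = argparse.ArgumentParser(description="Run MMLU evaluation with a ModelOpt Megatron model")
# Exactly one checkpoint source may be given; argparse enforces the exclusivity.
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--nemo_ckpt", help="Path to a NeMo checkpoint")
group.add_argument("--finetuned_ckpt_dir", help="Experiment dir; the most recent checkpoint is used")
args = parser.parse_args()
```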
Actionable comments posted: 2
♻️ Duplicate comments (10)
modelopt/torch/export/plugins/nemo_run.py (1)
59-71
: Good fix: explicit exceptions instead of asserts. Replacing asserts with FileNotFoundError makes the checks reliable at runtime.
examples/nemo_run/common/in_memory_mmlu.py (1)
31-38
: Nice: proper argparse mutually-exclusive group. This removes the runtime assert and gives standard CLI errors.
examples/nemo_run/common/process_openscience.py (1)
20-21
: Bug: loading on-disk dataset with load_dataset will fail; use load_from_disk. You save with `save_to_disk` and must load with `load_from_disk`.
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
@@
- ds = load_dataset(raw_dir)
+ ds = load_from_disk(raw_dir)
Also applies to: 38-44
examples/nemo_run/common/utils.py (1)
120-124
: Guard against missing recipe module to avoid AttributeError. Accessing `getattr(llm, recipe_name)`
unguarded can raise.-def get_finetune_recipe(recipe_name: str): - if not hasattr(getattr(llm, recipe_name), "finetune_recipe"): - raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe") - return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None) +def get_finetune_recipe(recipe_name: str): + recipe_mod = getattr(llm, recipe_name, None) + if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"): + raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe") + return recipe_mod.finetune_recipe(peft_scheme=None)examples/nemo_run/qat/nemo_qat_flow.py (3)
137-143
: Fix KV-cache tri-state; don’t force disable by default. Currently, `--disable_kv_cache` is always passed unless `--enable_kv_cache`
@@ - parser.add_argument( - "--enable_kv_cache", - help="Enables KV-cache quantization", - action="store_true", - default=False, - ) + kv_group = parser.add_mutually_exclusive_group() + kv_group.add_argument( + "--enable_kv_cache", + dest="enable_kv_cache", + help="Enable KV-cache quantization", + action="store_true", + ) + kv_group.add_argument( + "--disable_kv_cache", + dest="enable_kv_cache", + help="Disable KV-cache quantization", + action="store_false", + ) + parser.set_defaults(enable_kv_cache=None) @@ - ptq = run.Script( + # Build KV-cache flag only when explicitly set + kv_cache_flag = ( + ["--enable_kv_cache"] if args.enable_kv_cache is True + else (["--disable_kv_cache"] if args.enable_kv_cache is False else []) + ) + ptq = run.Script( @@ - "--kv_cache_qformat", - args.kv_cache_qformat, - "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache", + "--kv_cache_qformat", + args.kv_cache_qformat, + *kv_cache_flag,Also applies to: 183-201
146-153
: Separate “recipe name” from “filesystem prefix”; avoid getattr before validation.Avoid overloading
model_name
for two concepts; userecipe_name
for lookup andckpt_prefix
for paths.def main(args): - if not args.distill and not args.finetune_recipe: + if not args.distill and not args.finetune_recipe: raise ValueError("If distillation is not used, --finetune-recipe must be specified") - model_name = args.finetune_recipe - model_module = getattr(llm, model_name) - if not model_name: - model_name = os.path.basename(args.model_name) + recipe_name = args.finetune_recipe + model_module = getattr(llm, recipe_name) + ckpt_prefix = recipe_name or os.path.basename(args.model_name) @@ - bf16_ckpt_path = f"{exp_dir}/{model_name}-nemo" + bf16_ckpt_path = f"{exp_dir}/{ckpt_prefix}-nemo" @@ - ptq_model_out = f"{exp_dir}/{model_name}-{args.algorithm}" + ptq_model_out = f"{exp_dir}/{ckpt_prefix}-{args.algorithm}" @@ - export = run.Partial( - export_most_recent_ckpt, train.log.log_dir, output_path=f"{exp_dir}/{model_name}_hf" - ) + export = run.Partial( + export_most_recent_ckpt, train.log.log_dir, output_path=f"{exp_dir}/{ckpt_prefix}_hf" + )Also applies to: 171-173, 181-182, 248-251
274-283
: Use dedicated executors for PTQ, Train, and Export; don’t mutate a single SlurmExecutor.Reusing and mutating
gpu_executor
can request the wrong GPU count for later stages (and some fields aren’t honored after init). Create separate executors and usesingle_gpu_executor
for export.@@ - if args.use_slurm: - cpu_executor = create_slurm_executor(SLURM_CONFIG) - gpu_executor = create_slurm_executor( - SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus - ) - single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1) + if args.use_slurm: + cpu_executor = create_slurm_executor(SLURM_CONFIG) + ptq_gpu_executor = create_slurm_executor( + SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus + ) + train_gpu_executor = create_slurm_executor( + SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus + ) + single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1) else: - cpu_executor = single_gpu_executor = run.LocalExecutor() - gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus) + cpu_executor = single_gpu_executor = run.LocalExecutor() + ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus) + train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus) @@ - s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1]) + s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1]) @@ - if args.use_slurm: # Set training arguments - gpu_executor.nodes = args.train_nodes - gpu_executor.gpus_per_node = gpu_executor.ntasks_per_node = args.train_gpus - else: - gpu_executor.ntasks_per_node = args.train_gpus + # Use dedicated training executor; no mutation required train_dep = [s3] @@ - s4 = exp.add( - train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep - ) + s4 = exp.add( + train, tail_logs=True, name="05_train", executor=train_gpu_executor, dependencies=train_dep + ) @@ - gpu_executor.ntasks_per_node = 1 # will throw error if more than 1 task during export - exp.add( + exp.add( export, tail_logs=True, name="07_export_hf", - executor=gpu_executor, + executor=single_gpu_executor, dependencies=[s5], )Also applies to: 301-303, 311-316, 320-321, 330-336
examples/nemo_run/qat/ADVANCED.md (1)
5-5
: Avoid instructing users to hardcode credentials; prefer env/config-driven SlurmConfig.Recommend documenting required env vars or a YAML config and loading it in
nemo_qat_flow.py
instead of editing source.-To run the example on slurm, edit the `SLURM_CONFIG` at the bottom of `nemo_qat_flow.py` with the appropriate credentials, container, cluster name (host), and container mounts. Make sure you are mounting the NeMo and Megatron-LM repositories above in the Slurm cluster and that you've checked out the correct commits. +To run on Slurm, set the required environment variables (e.g., SLURM_ACCOUNT, SLURM_PARTITION_CPU/GPU, SLURM_TIME, SLURM_HOST, SLURM_USER, SLURM_JOB_DIR, CONTAINER_IMAGE, CONTAINER_MOUNTS) and construct `SLURM_CONFIG` from them in `nemo_qat_flow.py`. Avoid committing secrets or editing source with credentials.If you want, I can send a follow-up PR wiring env-var loading.
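As a rough illustration of that direction, a sketch that builds the config from environment variables (the field names mirror the `SlurmConfig` shown later in this PR; the variable names and defaults are assumptions):

```python
import os

from utils import SlurmConfig  # examples/nemo_run/common/utils.py in this PR

mounts = os.environ.get("CONTAINER_MOUNTS", "")

# Pull cluster-specific values from the environment instead of editing source.
SLURM_CONFIG = SlurmConfig(
    account=os.environ["SLURM_ACCOUNT"],
    partition_gpu=os.environ.get("SLURM_PARTITION_GPU", "batch"),
    partition_cpu=os.environ.get("SLURM_PARTITION_CPU", "cpu"),
    time=os.environ.get("SLURM_TIME", "04:00:00"),
    container_image=os.environ.get("CONTAINER_IMAGE", "nvcr.io/nvidia/nemo:25.07"),
    env_vars={"HF_TOKEN": os.environ.get("HF_TOKEN", "")},
    use_local_tunnel=False,
    host=os.environ["SLURM_HOST"],
    user=os.environ["SLURM_USER"],
    container_mounts=[m for m in mounts.split(",") if m],
    job_dir=os.environ["SLURM_JOB_DIR"],
    identity=os.environ.get("SLURM_IDENTITY"),
)
```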
examples/nemo_run/qat/README.md (2)
86-94
: Move Slurm flag details to ADVANCED.md; keep this section model-focused.
This repeats Slurm specifics already covered elsewhere.-Locally this script currently supports models that can be trained on 1 node with 8 x 80GB GPUs. On Slurm you can configure the number of nodes/gpus for training and PTQ with the following flags: `--train-nodes`, `--train-gpus`, `--ptq-gpus`. +Locally this script currently supports models that can be trained on 1 node with 8 x 80GB GPUs. For Slurm configuration (nodes/GPUs for training and PTQ), see [Slurm Examples](ADVANCED.md).#!/bin/bash # Verify flags mentioned in docs exist in the CLI. rg -nP -C2 '(--train-nodes|--train-gpus|--ptq-gpus|--enable-kv-cache|--kv-cache-qformat|--distill)\b' examples -S
41-41
: Consider moving “Usage” right after “Overview”.
Improves discoverability; aligns with prior feedback.
🧹 Nitpick comments (18)
modelopt/torch/export/plugins/nemo_run.py (3)
44-47
: Make most-recent selection deterministic. Tie on mtime yields nondeterministic picks. Add name as a secondary key.
- most_recent = max(subdirs, key=lambda x: x.stat().st_mtime)
+ most_recent = max(subdirs, key=lambda x: (x.stat().st_mtime, x.name))
24-35
: Add return annotation and clarify API intent. Small polish: annotate return type and note that input expects a NeMo Run experiment root containing a default/checkpoints hierarchy.
-def export_most_recent_ckpt(directory: str, output_path: str): +def export_most_recent_ckpt(directory: str, output_path: str) -> None: """Export most recent checkpoint from a NeMo Run experiment directory. + + `directory` should be the experiment root containing `default/` with + either `default/checkpoints/*` or `default/<run>/checkpoints/*`. """
50-53
: Avoid importing a private helper across modules; promote to public API. This module’s
_get_most_recent_ckpt
is imported from another file. Either rename to a public helper or re-export a public alias.-def _get_most_recent_ckpt(directory: str): +def get_most_recent_ckpt(directory: str): @@ - return str(most_recent) + return str(most_recent)Then adjust imports in the caller to use
get_most_recent_ckpt
. If you prefer to keep the private name, addget_most_recent_ckpt = _get_most_recent_ckpt
at module end.Also applies to: 59-73
examples/nemo_run/common/in_memory_mmlu.py (3)
26-29
: Fix spacing in help description.The concatenated strings currently render “…--nemo_ckptor…”. Add a space.
- "Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt" - "or --finetuned_ckpt_dir" + "Run MMLU evaluation with ModelOpt Megatron model. Provide either --nemo_ckpt " + "or --finetuned_ckpt_dir"
50-53
: Don’t import/use a private helper; switch to public API.Use
get_most_recent_ckpt
(see plugin comment) instead of_get_most_recent_ckpt
.-from modelopt.torch.export.plugins.nemo_run import _get_most_recent_ckpt +from modelopt.torch.export.plugins.nemo_run import get_most_recent_ckpt @@ - if args.finetuned_ckpt_dir: - ckpt_path = _get_most_recent_ckpt(args.finetuned_ckpt_dir) + if args.finetuned_ckpt_dir: + ckpt_path = get_most_recent_ckpt(args.finetuned_ckpt_dir)
50-52
: Log chosen checkpoint for traceability.Helpful when resolving “wrong model evaluated” issues.
ckpt_path = args.nemo_ckpt if args.finetuned_ckpt_dir: ckpt_path = get_most_recent_ckpt(args.finetuned_ckpt_dir) + print(f"Evaluating checkpoint: {ckpt_path}")
examples/nemo_run/common/process_openscience.py (1)
57-59
: Create parent dirs when making proc_dir.Avoid failures if parent path is missing.
- Path(proc_dir).mkdir(exist_ok=True) + Path(proc_dir).mkdir(parents=True, exist_ok=True)examples/nemo_run/common/utils.py (2)
107-117
: Apply env vars and consider ntasks_per_node for CPU branch.Parity with GPU path and user expectations.
- return run.SlurmExecutor( + return run.SlurmExecutor( account=slurm_cfg.account, partition=slurm_cfg.partition_cpu, + ntasks_per_node=ntasks_per_node, nodes=nodes, tunnel=tunnel, container_image=slurm_cfg.container_image, container_mounts=slurm_cfg.container_mounts, time=slurm_cfg.time, packager=run.GitArchivePackager(), mem="0", + env_vars=slurm_cfg.env_vars, # verify exact param name )
79-82
: Ensure LocalTunnel has a usable job_dir.If
use_local_tunnel=True
andjob_dir
is empty, LocalTunnel may fail.- tunnel = run.LocalTunnel(job_dir=slurm_cfg.job_dir) + tunnel = run.LocalTunnel(job_dir=slurm_cfg.job_dir or ".")Alternatively, validate
job_dir
for both tunnel modes in__post_init__
.examples/nemo_run/qat/nemo_qat_flow.py (2)
155-169
: Remove stale TODO and clarify path handling.Comments reference
common/process.py
, but the code usesprocess_openscience.py
. Clean up to avoid confusion.- # TODO figure out path - # LOCALLY common/process.py works - # On slurm examples/nemo_run/common/process.py works
118-129
: Naming clarity: consider--train-gpus-per-node
(and possibly--ptq-gpus-per-node
).Avoid ambiguity between total GPUs vs per-node. Aligns with Slurm executor semantics.
examples/nemo_run/qat/ADVANCED.md (1)
11-16
: Add fenced code languages for markdownlint (MD040).Annotate blocks as
text
to satisfy linters.-``` +```text qat_flow_ckpts qat_flow_ckpts_1755708286@@
-+
text
├── 00_openscience_data
│ ├── code
│ ├── configs
│ ├── log-coreai_dlalgo_modelopt-modelopt.00_openscience_data_5345664_0.out
│ └── sbatch_coreai_dlalgo_modelopt-modelopt.00_openscience_data_5345664.out
...Also applies to: 19-56
examples/llm_qat/README.md (1)
14-15
: LGTM — helpful cross-link to NeMo QAT/QAD flow.Consider adding a Docs link when available.
examples/nemo_run/qat/README.md (5)
5-7
: Deduplicate header links to ADVANCED.md.
Both “Slurm Examples” and “Advanced Topics” point to the same file; keep one.-[Slurm Examples](ADVANCED.md) | -[Advanced Topics](ADVANCED.md) | +[Advanced Topics (incl. Slurm)](ADVANCED.md) |
45-45
: Fix double space.-You can run the example either locally or on a [Slurm cluster](ADVANCED.md). +You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
49-56
: Avoid bind-mounting into site-packages; prefer PYTHONPATH.
Mounting into /usr/local/lib/python3.12/dist-packages is version-specific and brittle. Use PYTHONPATH to point at local checkouts (orpip install -e
inside the container).-Example docker command: +Example docker command (avoid writing into site-packages; use PYTHONPATH to point to your local checkouts): @@ -docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash +docker run --gpus all -it --shm-size 20g --rm \ + -v /home/user/NeMo:/opt/NeMo \ + -v /home/user/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \ + -e PYTHONPATH=/opt/NeMo:/workspace/TensorRT-Model-Optimizer \ + nvcr.io/nvidia/nemo:25.07 bash
58-59
: Branding and permissions guidance.
Use “Hugging Face” and avoidchmod 777
; point users to a writable log dir instead.-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written. +Set your Hugging Face token with `export HF_TOKEN=<your-token>`. Ensure your `--log-dir` is a writable path you own (e.g., `mkdir -p /my/log/dir`), rather than using `chmod 777`.
66-66
: Minor wording/brand capitalization and punctuation.-From the `nemo_run` folder, launch the example with the `qat/nemo_qat_flow.py` script. To use a different model than the default model (Qwen3-8B), you can add the `--model-name <hf-model-name> --finetune-recipe <recipe-name>` flags and use the model's HuggingFace name and NeMo recipe names listed [here](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/recipes). To provide your own custom dataset, use the `--data-path` flag, otherwise the default [NVIDIA OpenScience](https://huggingface.co/datasets/nvidia/OpenScience) dataset will be used. +From the `nemo_run` folder, launch the example with the `qat/nemo_qat_flow.py` script. To use a different model than the default model (Qwen3-8B), you can add the `--model-name <hf-model-name> --finetune-recipe <recipe-name>` flags and use the model's Hugging Face name and NeMo recipe names listed [here](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/recipes). To provide your own custom dataset, use the `--data-path` flag; otherwise, the default [NVIDIA OpenScience](https://huggingface.co/datasets/nvidia/OpenScience) dataset will be used.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
examples/llm_qat/README.md (1 hunks)
examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
examples/nemo_run/common/process_openscience.py (1 hunks)
examples/nemo_run/common/utils.py (1 hunks)
examples/nemo_run/qat/ADVANCED.md (1 hunks)
examples/nemo_run/qat/README.md (2 hunks)
examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/nemo_run/common/in_memory_mmlu.py (2)
modelopt/torch/export/plugins/nemo_run.py (1)
_get_most_recent_ckpt
(50-73)modelopt/torch/utils/plugins/megatron_mmlu.py (1)
megatron_mmlu
(65-152)
examples/nemo_run/common/utils.py (1)
examples/nemo_run/common/process_lima.py (1)
download_hf_dataset
(29-37)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
export_most_recent_ckpt
(24-35)examples/nemo_run/common/utils.py (4)
SlurmConfig
(24-72)create_slurm_executor
(75-117)get_finetune_recipe
(120-123)read_chat_template
(126-128)examples/nemo_run/common/in_memory_mmlu.py (1)
parse_args
(24-45)
🪛 markdownlint-cli2 (0.17.2)
examples/nemo_run/qat/ADVANCED.md
13-13: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
19-19: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (7)
examples/nemo_run/common/in_memory_mmlu.py (1)
57-58
: Good: devices computed from TP x PP. Matches the prior suggestion to wire devices to parallelism settings.
examples/nemo_run/common/utils.py (1)
91-106
: Verify `SlurmExecutor` accepts an `env_vars` parameter. We weren’t able to locate or import `SlurmExecutor` in this repo or via `nemo_run`—please confirm (via your installed package or docs) that its constructor supports an `env_vars` argument before merging.
examples/nemo_run/qat/nemo_qat_flow.py (1)
183-201
: Confirmed `ptq.py` CLI arguments: `-nc` maps to `--nemo_checkpoint`, `-out` to the export/save path, and `-ctp` to the calibration tensor-parallel size (torchrun’s nproc_per_node must match this). No changes needed.
examples/nemo_run/qat/ADVANCED.md (1)
11-12
: Verify rsync availability in the referenced container/tag. If missing in `nvcr.io/nvidia/nemo:25.07`, add install steps or suggest an alternative sync method.
examples/nemo_run/qat/README.md (3)
21-29
: Flow stages list looks accurate and uses exact task names.
Matches the 8-stage design and prior feedback.
30-39
: Mermaid diagram aligns with the listed stages.
Edges reflect the intended data/control flow.
15-16
: Good fix on the PTQ sentence.
Clear explanation of expected accuracy drop and QAT/QAD intent.
def main(args):
    if not args.distill and not args.finetune_recipe:
        raise ValueError("If distillation is not used, --finetune-recipe must be specified")
    model_name = args.finetune_recipe
    model_module = getattr(llm, model_name)
    if not model_name:
        model_name = os.path.basename(args.model_name)
    exp_dir = f"{args.log_dir.rstrip('/')}/{args.experiment}"
🛠️ Refactor suggestion
Guard globals used by main() to support programmatic invocation.
SEQUENCE_LENGTH
, MBS
, GBS
, TRAIN_STEPS
, VAL_INTERVAL
are set under __main__
but read inside main()
. If imported and called directly, this raises NameError. Add fallbacks in main()
.
def main(args):
+ # Fallbacks if module constants aren’t initialized (e.g., when imported)
+ global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+ SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+ MBS = globals().get("MBS", 1)
+ GBS = globals().get("GBS", 512)
+ TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+ VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)
Optionally, hoist the constants to module scope instead of defining under __main__
.
Also applies to: 220-231, 368-376
🤖 Prompt for AI Agents
In examples/nemo_run/qat/nemo_qat_flow.py around lines 146-153 (also apply same
change to blocks at 220-231 and 368-376), main() reads globals SEQUENCE_LENGTH,
MBS, GBS, TRAIN_STEPS, VAL_INTERVAL which are only defined under __main__,
causing NameError when the module is imported and main() is called
programmatically; add explicit fallbacks inside main() by reading each with
globals().get("NAME", <sensible_default>) or os.environ fallback (e.g.,
SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 2048) etc.), or better
move/hoist the constant definitions from the __main__ block to module scope so
they are always defined before main() is used; update all referenced locations
(lines listed) consistently.
Actionable comments posted: 0
♻️ Duplicate comments (4)
examples/nemo_run/qat/nemo_qat_flow.py (4)
351-351
: Slurm time format fixed. LGTM. The HH:MM:SS format resolves SlurmConfig validation errors.
138-143
: Fix KV-cache flag tri-state; don’t force disable by default.When the user passes neither flag, we should omit both PTQ flags. Currently,
--disable_kv_cache
is always injected becauseenable_kv_cache
defaults to False. Build the flag list conditionally and parse a tri-state.@@ - parser.add_argument( - "--enable_kv_cache", - help="Enables KV-cache quantization", - action="store_true", - default=False, - ) + kv_group = parser.add_mutually_exclusive_group(required=False) + kv_group.add_argument( + "--enable_kv_cache", + dest="enable_kv_cache", + action="store_true", + help="Enables KV-cache quantization", + ) + kv_group.add_argument( + "--disable_kv_cache", + dest="enable_kv_cache", + action="store_false", + help="Disables KV-cache quantization", + ) + parser.set_defaults(enable_kv_cache=None) @@ - ptq = run.Script( + # Build KV-cache flag only when explicitly set + kv_cache_flag = ( + ["--enable_kv_cache"] if args.enable_kv_cache is True + else (["--disable_kv_cache"] if args.enable_kv_cache is False else []) + ) + ptq = run.Script( "/opt/NeMo/scripts/llm/ptq.py", args=[ "-nc", bf16_ckpt_path, "-out", ptq_model_out, "--export_format", "nemo", "--algorithm", args.algorithm, "--kv_cache_qformat", args.kv_cache_qformat, - "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache", + *kv_cache_flag, "-ctp", f"{args.ptq_gpus}", ], entrypoint="python", )Also applies to: 183-199
274-283
: Use separate executors for PTQ, Train, and Export; stop mutating a shared Slurm executor.Mutating
gpu_executor
risks stalegres/gpus_per_node
in Slurm and incorrect local torchrun settings. Create dedicated executors and use the single-GPU executor for export.@@ - if args.use_slurm: - cpu_executor = create_slurm_executor(SLURM_CONFIG) - gpu_executor = create_slurm_executor( - SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus - ) - single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1) - else: - cpu_executor = single_gpu_executor = run.LocalExecutor() - gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus) + if args.use_slurm: + cpu_executor = create_slurm_executor(SLURM_CONFIG) + ptq_gpu_executor = create_slurm_executor( + SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus + ) + train_gpu_executor = create_slurm_executor( + SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus + ) + single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1) + else: + cpu_executor = single_gpu_executor = run.LocalExecutor() + ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus) + train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus) @@ - s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1]) + s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1]) @@ - if args.use_slurm: # Set training arguments - gpu_executor.nodes = args.train_nodes - gpu_executor.gpus_per_node = gpu_executor.ntasks_per_node = args.train_gpus - else: - gpu_executor.ntasks_per_node = args.train_gpus + # use dedicated training executor; no mutation needed @@ - s4 = exp.add( - train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep - ) + s4 = exp.add( + train, tail_logs=True, name="05_train", executor=train_gpu_executor, dependencies=train_dep + ) @@ - gpu_executor.ntasks_per_node = 1 # will throw error if more than 1 task during export - exp.add( - export, - tail_logs=True, - name="07_export_hf", - executor=gpu_executor, - dependencies=[s5], - ) + exp.add( + export, + tail_logs=True, + name="07_export_hf", + executor=single_gpu_executor, + dependencies=[s5], + )Also applies to: 302-303, 311-316, 320-321, 330-336
146-153
: Guard module-level constants and SLURM_CONFIG for programmaticmain()
calls.Importing and calling
main(args)
will raise NameError for undefined globals andSLURM_CONFIG
. Add safe fallbacks and a clear error when Slurm config is missing.def main(args): + # Fallbacks if __main__ block wasn't executed + global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL + SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096) + MBS = globals().get("MBS", 1) + GBS = globals().get("GBS", 512) + TRAIN_STEPS = globals().get("TRAIN_STEPS", 200) + VAL_INTERVAL = globals().get("VAL_INTERVAL", 50) + if args.use_slurm and "SLURM_CONFIG" not in globals(): + raise ValueError( + "SLURM_CONFIG must be defined before calling main() with --use-slurm. " + "See the __main__ section for an example." + )
🧹 Nitpick comments (3)
examples/nemo_run/qat/nemo_qat_flow.py (3)
147-149
: Dead condition; can be removed.
--finetune-recipe
has a default, so this branch never triggers.- if not args.distill and not args.finetune_recipe: - raise ValueError("If distillation is not used, --finetune-recipe must be specified")
151-153
: Unreachable fallback formodel_name
.
model_name
is set fromargs.finetune_recipe
and is always truthy; this block never runs.- if not model_name: - model_name = os.path.basename(args.model_name)
29-31
: Avoid mutating sys.path; prefer package-relative imports.Relying on
sys.path.append
is brittle for both local and Slurm. If feasible, packageexamples
and use absolute imports (or compute script paths without altering sys.path).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
export_most_recent_ckpt
(24-35)examples/nemo_run/common/utils.py (4)
SlurmConfig
(24-72)create_slurm_executor
(75-117)get_finetune_recipe
(120-123)read_chat_template
(126-128)examples/nemo_run/common/in_memory_mmlu.py (1)
parse_args
(24-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
examples/nemo_run/qat/nemo_qat_flow.py (1)
338-339
: Confirm detach behavior.
exp.run(detach=True)
backgrounds the run. If interactive logs are desired by default, consider making this a CLI flag (e.g.,--detach
).@@ def get_args(): + parser.add_argument("--detach", action="store_true", help="Run experiment detached", default=False) @@ - exp.run(detach=True) + exp.run(detach=args.detach)
Actionable comments posted: 0
♻️ Duplicate comments (5)
examples/nemo_run/qat/nemo_qat_flow.py (5)
138-142
: Fix the KV-cache flag tri-state handlingThe current implementation only uses
--enable_kv_cache
when the flag is set to True, but always defaults to--disable_kv_cache
otherwise (line 198). This prevents a neutral state where neither flag is passed.Apply this diff to properly handle the tri-state:
parser.add_argument( "--enable_kv_cache", help="Enables KV-cache quantization", action="store_true", - default=False, ) + parser.add_argument( + "--disable_kv_cache", + help="Disables KV-cache quantization", + action="store_true", + )Then update the PTQ args construction:
+ # Build KV-cache flag only when explicitly set + kv_cache_args = [] + if args.enable_kv_cache: + kv_cache_args.append("--enable_kv_cache") + elif args.disable_kv_cache: + kv_cache_args.append("--disable_kv_cache") + ptq = run.Script( "/opt/NeMo/scripts/llm/ptq.py", args=[ "-nc", bf16_ckpt_path, "-out", ptq_model_out, "--export_format", "nemo", "--algorithm", args.algorithm, "--kv_cache_qformat", args.kv_cache_qformat, - "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache", + *kv_cache_args, "-ctp", f"{args.ptq_gpus}", ],
247-249
: Consider tensor parallelism configuration based on GPU countSetting both tensor and pipeline parallelism to 1 by default means the model runs with data parallelism only. For larger models or when using fewer GPUs with limited memory, tensor parallelism might be required.
Consider automatically setting tensor parallelism based on the number of GPUs:
train.trainer.strategy.tensor_model_parallel_size = args.tensor_parallelism train.trainer.strategy.pipeline_model_parallel_size = args.pipeline_parallelism + # Optionally, auto-configure TP if not specified and model is large + # if args.tensor_parallelism == 1 and "70B" in args.model_name: + # train.trainer.strategy.tensor_model_parallel_size = min(8, args.train_gpus)
279-286
: Verify that GPUs are not overprovisioned in Slurm executors. Creating separate executors for different GPU counts looks correct now. This properly addresses the previous issue of reusing executors with mutated parameters.
360-360
: Slurm time format is correct. The time format "04:00:00" follows the HH:MM:SS format required by SlurmConfig validation.
226-233
: Global constants are undefined - will cause NameError. The code references
SEQUENCE_LENGTH
,GBS
,MBS
,TRAIN_STEPS
, andVAL_INTERVAL
which are only defined within the__main__
block (lines 375-379). This will raise a NameError whenmain()
is called.Move the constants to module scope before the
main()
function:import nemo_run as run from nemo.collections import llm # ... other imports ... +# Configurable parameters +SEQUENCE_LENGTH = 4096 +MBS = 1 +GBS = 512 +TRAIN_STEPS = 400 +VAL_INTERVAL = 50 + def get_args():And remove the duplicate definitions from the
__main__
block (lines 373-380).
🧹 Nitpick comments (5)
examples/nemo_run/qat/nemo_qat_flow.py (5)
339-340
: WAR comment indicates a NeMo export bugThe workaround mutates
train_gpu_executor.ntasks_per_node = 1
to handle a NeMo bug. This is fragile and could break if the executor is reused.Consider using a dedicated single-task executor for export to avoid mutation:
- # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo - train_gpu_executor.ntasks_per_node = 1 # will throw error if more than 1 task during export + # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo + export_executor = create_slurm_executor( + SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=1 + ) if args.use_slurm else run.LocalExecutor() exp.add( export, tail_logs=True, name="07_export_hf", - executor=train_gpu_executor, + executor=export_executor, dependencies=[s5], )Would you like me to open an issue to track the underlying NeMo export bug?
357-370
: Empty Slurm configuration values will raise validation errorsThe SlurmConfig has empty strings for required fields like
account
,host
, anduser
. These will trigger validation errors inSlurmConfig.__post_init__()
.Add a comment to guide users on what needs to be configured:
if args.use_slurm: + # IMPORTANT: Fill in these required fields before running on Slurm SLURM_CONFIG = SlurmConfig( - account="", + account="", # REQUIRED: Your Slurm account name partition_gpu="batch", partition_cpu="cpu", time="04:00:00", container_image="nvcr.io/nvidia/nemo:25.07", env_vars={ - "HF_TOKEN": "", + "HF_TOKEN": "", # REQUIRED if using gated HF models }, use_local_tunnel=False, - host="", - user="", + host="", # REQUIRED: Slurm cluster hostname (e.g., "cluster.example.com") + user="", # REQUIRED: Your username on the cluster container_mounts=[], - job_dir="/path/to/logs", + job_dir="/path/to/logs", # REQUIRED: Directory for job logs on cluster identity=None, )
143-144
: Parallelism arguments lack help text. The tensor and pipeline parallelism arguments are missing help descriptions, making their purpose unclear to users.
- parser.add_argument("--tensor_parallelism", type=int, default=1) - parser.add_argument("--pipeline_parallelism", type=int, default=1) + parser.add_argument( + "--tensor_parallelism", + type=int, + default=1, + help="Tensor parallelism degree for model training" + ) + parser.add_argument( + "--pipeline_parallelism", + type=int, + default=1, + help="Pipeline parallelism degree for model training" + )
347-347
: Consider adding timeout handling for long-running experimentsThe experiment runs with
detach=True
, which is good for long-running jobs. However, there's no timeout or monitoring mechanism for stuck jobs.Consider adding a comment about monitoring:
- exp.run(detach=True) + exp.run(detach=True) # Detached mode for long-running jobs + # Monitor job progress via: exp.status() or check logs at {exp_dir}
158-164
: Resolve process_openscience.py path to an absolute path. File: examples/nemo_run/qat/nemo_qat_flow.py (lines 158–164)
Confirmed examples/nemo_run/common/process_openscience.py exists in the repo; the Slurm branch’s use of a CWD-relative literal can still fail at runtime — resolve the script path relative to this file and validate existence before calling run.Script. Suggested replacement:
openscience_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "common", "process_openscience.py"))
if not os.path.exists(openscience_path):
raise FileNotFoundError(f"OpenScience script not found at {openscience_path}")
openscience_data = run.Script(openscience_path)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/nemo_run/qat/nemo_qat_flow.py (3)
modelopt/torch/export/plugins/nemo_run.py (1)
export_most_recent_ckpt
(24-35)examples/nemo_run/common/utils.py (4)
SlurmConfig
(24-72)create_slurm_executor
(75-117)get_finetune_recipe
(120-123)read_chat_template
(126-128)examples/nemo_run/common/in_memory_mmlu.py (1)
parse_args
(24-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
examples/nemo_run/qat/nemo_qat_flow.py (2)
257-276
: MMLU evaluation script path logic matches OpenScience pattern. The conditional path resolution for the MMLU script follows the same pattern as the OpenScience script, maintaining consistency.
235-239
: Add graceful handling for missing distillation recipe. get_finetune_recipe already validates and raises ValueError (examples/nemo_run/common/utils.py:120–123). distillation_recipe is imported at examples/nemo_run/qat/nemo_qat_flow.py:24 but its implementation is not in this repo — wrap the distillation call (examples/nemo_run/qat/nemo_qat_flow.py:235–239) in a presence check or try/except to surface a clear error like the finetune path.
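A minimal sketch of such a guard (illustrative only; the import path is a placeholder for the file's existing import, and no recipe argument names are assumed):

```python
# Make the distillation dependency optional and fail with a clear message only
# when --distill is actually requested, mirroring the finetune-path validation.
try:
    from distillation_module import distillation_recipe  # placeholder for the real import at line 24
except ImportError:
    distillation_recipe = None


def require_distillation_recipe(args):
    if args.distill and distillation_recipe is None:
        raise ValueError(
            "--distill was requested but the distillation recipe could not be imported; "
            "check the NeMo/Megatron installation or run without --distill."
        )
    return distillation_recipe
```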
e4c8597 to 1f9996a (Compare)
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (2)
65-74
: Fix return type mismatch in load_engine (function returns a tuple).The annotation/doc promise an engine, but the function returns
(engine | None, error_str)
. Correct the signature/doc to avoid downstream type/logic bugs.Apply:
-def load_engine(buffer: bytes, log_level: int = trt.Logger.ERROR) -> trt.tensorrt.ICudaEngine: - """Load a TensorRT engine from engine data and return.""" +def load_engine( + buffer: bytes, log_level: int = trt.Logger.ERROR +) -> tuple[trt.tensorrt.ICudaEngine | None, str]: + """Load a TensorRT engine from engine data. + + Returns: + (engine | None, error_message) + """
170-179
: Avoid double‑hashing; fix doc typo.
hashlib.sha256(engine_bytes)
already digests the data; callingupdate(engine_bytes)
again computes SHA256(engine_bytes || engine_bytes).-def prepend_hash_to_bytes(engine_bytes: bytes) -> bytes: - """Prepend the engine bytes with the SHA256 hash of the engine bytes - This has will serve as a unique identifier for the engine and will be used to manage - TRTSessions in the TRTClient. - """ - hash_object = hashlib.sha256(engine_bytes) - hash_object.update(engine_bytes) - hash_bytes = hash_object.digest() - engine_bytes = hash_bytes + engine_bytes - return engine_bytes +def prepend_hash_to_bytes(engine_bytes: bytes) -> bytes: + """Prepend the engine bytes with the SHA256 hash of the engine bytes. + This hash serves as a unique identifier for the engine and is used to manage + TRTSessions in the TRTClient. + """ + hash_bytes = hashlib.sha256(engine_bytes).digest() + return hash_bytes + engine_bytesexamples/speculative_decoding/launch.sh (1)
92-95
: Fix divide‑by‑zero when no GPU is present.GPU_COUNT can be 0 causing an arithmetic error; fall back to 1 before dividing. Apply at examples/speculative_decoding/launch.sh (around lines 92–95):
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) +# Calculate save_steps (fallback to 1 when no GPU is detected) +if [[ -z "$GPU_COUNT" || "$GPU_COUNT" -le 0 ]]; then + GPU_COUNT=1 +fi +DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))Test on a CPU-only node to confirm the script no longer exits early.
modelopt/onnx/quantization/qdq_utils.py (1)
639-642
: Create FP8 ONNX tensors with correct dtype (bug fix).
numpy_helper.from_array
will set dtype to UINT8, not FLOAT8. Useonnx.helper.make_tensor(..., data_type=FLOAT8, raw=True)
like the MXFP8 path below.Apply this diff:
-def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto: - """Create a FLOAT8E4M3FN tensor directly from numpy array.""" - fp8_data = _cast_fp8(scaled) - return onnx.numpy_helper.from_array(fp8_data, weight_name) +def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto: + """Create a FLOAT8E4M3FN tensor directly from numpy array.""" + fp8_bytes = _cast_fp8(scaled).tobytes() + return onnx.helper.make_tensor( + name=weight_name, + data_type=onnx_dtype_map["Float8"], + dims=list(scaled.shape), + vals=fp8_bytes, + raw=True, + )tests/unit/onnx/test_qdq_utils.py (1)
35-36
: Fix MatMul shape inconsistency in the synthetic graph (pre/post-quantization).As written, the MatMul inputs are dimensionally incompatible both before and after the Reshape/Transpose removal. Make the reshape produce (8, 32) and keep the post‑transpose as (32, 8), then drive MatMul with input [..., 32] to yield output [..., 8]. Also avoid orphaning the original scale initializer when
constant_scale=True
.Apply:
- reshape_shape = np.array([16, 16], dtype=np.int64) + reshape_shape = np.array([8, 32], dtype=np.int64)- reshape_output_info = helper.make_tensor_value_info( - "reshape_output", TensorProto.FLOAT, [16, 16] - ) + reshape_output_info = helper.make_tensor_value_info( + "reshape_output", TensorProto.FLOAT, [8, 32] + )- graph = helper.make_graph( - nodes=nodes, - name="test_graph", - inputs=[input_tensor], - outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 16])], - initializer=[weight_tensor, scale_tensor], - value_info=[reshape_output_info], - ) + initializers = [weight_tensor] if constant_scale else [weight_tensor, scale_tensor] + graph = helper.make_graph( + nodes=nodes, + name="test_graph", + inputs=[input_tensor], # make sure this is [None, 32] above where defined + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 8])], + initializer=initializers, + value_info=[reshape_output_info], + )Also update the input tensor shape at line 38 to
[None, 32]
to match. If you prefer not to change I/O shapes, alternatively setreshape_shape = [32, 8]
andperm=[1,0]
to be a no‑op on shape.Also applies to: 91-96, 100-106
examples/speculative_decoding/export_hf_checkpoint.py (1)
38-49
: Move execution under a main guard.Avoids side effects on import and clarifies entrypoint.
Apply this diff:
-mto.enable_huggingface_checkpointing() - -args = parse_args() -model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto") -model.eval() -with torch.inference_mode(): - export_hf_checkpoint( - model, # The quantized model. - export_dir=args.export_path, # The directory where the exported files will be stored. - ) -print(f"Exported checkpoint to {args.export_path}") +def main(): + mto.enable_huggingface_checkpointing() + args = parse_args() + model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto") + model.eval() + with torch.inference_mode(): + export_hf_checkpoint(model, export_dir=args.export_path) + print(f"Exported checkpoint to {args.export_path}") + +if __name__ == "__main__": + main()
♻️ Duplicate comments (5)
examples/nemo_run/qat/nemo_qat_flow.py (5)
137-145
: KV‑cache flag should be tri‑state; avoid forcing disable by default.Only append --enable_kv_cache/--disable_kv_cache when explicitly requested.
Apply:
- parser.add_argument( - "--enable_kv_cache", - help="Enables KV-cache quantization", - action="store_true", - default=False, - ) + kv_group = parser.add_mutually_exclusive_group() + kv_group.add_argument( + "--enable_kv_cache", + dest="enable_kv_cache", + action="store_true", + help="Enables KV-cache quantization", + ) + kv_group.add_argument( + "--disable_kv_cache", + dest="enable_kv_cache", + action="store_false", + help="Disables KV-cache quantization", + ) + parser.set_defaults(enable_kv_cache=None)- ptq = run.Script( + # Build KV-cache flag only when explicitly set + kv_cache_flag = ( + ["--enable_kv_cache"] if args.enable_kv_cache is True + else (["--disable_kv_cache"] if args.enable_kv_cache is False else []) + ) + ptq = run.Script( "/opt/NeMo/scripts/llm/ptq.py", args=[ @@ - "--kv_cache_qformat", - args.kv_cache_qformat, - "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache", + "--kv_cache_qformat", + args.kv_cache_qformat, + *kv_cache_flag,Also applies to: 185-203
360-361
: Slurm time format fix acknowledged. Using HH:MM:SS (“04:00:00”) resolves SlurmConfig validation.
338-346
: Don’t mutate the train executor for export; use the single‑GPU executor to avoid Slurm resource mismatches.Changing only ntasks_per_node leaves gres/gpus_per_node stale on Slurm and can cause mis-scheduled jobs. Reuse the already-created single_gpu_executor for export and drop the mutation.
Apply:
- # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo - train_gpu_executor.ntasks_per_node = 1 # will throw error if more than 1 task during export - exp.add( + exp.add( export, tail_logs=True, name="07_export_hf", - executor=train_gpu_executor, + executor=single_gpu_executor, dependencies=[s5], )Also applies to: 339-339
222-233
: Guard globals and config for programmatic invocation.SEQUENCE_LENGTH/GBS/MBS/TRAIN_STEPS/VAL_INTERVAL and SLURM_CONFIG are referenced inside main() but defined under main. Importing and calling main(args) will raise NameError.
Apply:
def main(args): + # Ensure module-scope defaults exist when called programmatically + global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL, SLURM_CONFIG + SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096) + MBS = globals().get("MBS", 1) + GBS = globals().get("GBS", 512) + TRAIN_STEPS = globals().get("TRAIN_STEPS", 400) + VAL_INTERVAL = globals().get("VAL_INTERVAL", 50) + if args.use_slurm and "SLURM_CONFIG" not in globals(): + raise ValueError( + "SLURM_CONFIG must be defined (see __main__ block) when --use-slurm is set." + )Also applies to: 278-287, 375-380
283-286
: Pass nodes to the Slurm training executor.Multi-node training won’t be honored without nodes=args.train_nodes.
Apply:
- train_gpu_executor = create_slurm_executor( - SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus - ) + train_gpu_executor = create_slurm_executor( + SLURM_CONFIG, + nodes=args.train_nodes, + num_gpus=args.train_gpus, + ntasks_per_node=args.train_gpus, + )
🧹 Nitpick comments (77)
modelopt/onnx/autocast/precisionconverter.py (4)
179-184
: Use HasField and clear the oneof when anonymizing dim_value.
if d.dim_value:
misses dimensions explicitly set to 0 and doesn't make the oneof switch explicit. PreferHasField("dim_value")
and clear the field before settingdim_param
. If the intent is to preserve explicit 0-dims, gate withand d.dim_value != 0
.Apply this diff in the value_info loop:
- for vi in self.model.graph.value_info: - vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(vi.type.tensor_type.shape.dim): - if d.dim_value: - vi.type.tensor_type.shape.dim[idx].dim_param = "unk" + for vi in self.model.graph.value_info: + vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED + for d in vi.type.tensor_type.shape.dim: + # If preserving explicit 0-dims, add: and d.dim_value != 0 + if d.HasField("dim_value"): + d.ClearField("dim_value") + d.dim_param = "unk"
185-189
: Mirror the safer oneof handling for graph outputs.Same concern as above: use
HasField("dim_value")
and clear it before settingdim_param
to avoid silently skipping explicit 0 or leaving implicit defaults ambiguous.Apply this diff in the outputs loop:
- for out in self.model.graph.output: - out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(out.type.tensor_type.shape.dim): - if d.dim_value: - out.type.tensor_type.shape.dim[idx].dim_param = "unk" + for out in self.model.graph.output: + out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED + for d in out.type.tensor_type.shape.dim: + if d.HasField("dim_value"): + d.ClearField("dim_value") + d.dim_param = "unk"
191-194
: Defaulting UNDEFINED types to low precision can mis-type non-float tensors.Blindly setting all remaining UNDEFINED
value_info
toself.low_precision_type
risks clobbering indices/bool/int tensors (e.g., TopK/ArgMax paths) and may only be caught late by strict checks. Gate this by the original element type (when known) or by neighboring float-only usage.Apply this diff to constrain the assignment using the original
value_info_map
:- self._ensure_types_are_defined() + self._ensure_types_are_defined()And update
_ensure_types_are_defined
per next comment.
207-213
: Constrain fallback type assignment to originally-float tensors; add optional neighbor hint.Use the original
value_info_map
to only default tensors that were float-typed before we cleared metadata. Optionally, fall back to a lightweight neighbor check.Apply this diff:
- def _ensure_types_are_defined(self): - """Ensure that all tensor types are defined.""" - for vi in self.model.graph.value_info: - if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED: - vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type + def _ensure_types_are_defined(self): + """Ensure float tensors have defined types without clobbering non-float tensors.""" + # Build a quick lookup of original elem_types (pre-mutation) + orig_types = { + name: info.type.tensor_type.elem_type + for name, info in (self.value_info_map or {}).items() + } + for vi in self.model.graph.value_info: + if vi.type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED: + continue + name = vi.name + # Prefer original knowledge + if name in orig_types and orig_types[name] in ONNX_TYPES: + vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type + continue + # Lightweight neighbor hint: if any producer/consumer is a Cast or float op, assume float + producers = utils.get_producer_nodes(self.model, name) + consumers = utils.get_consumer_nodes(self.model, name) + neighbors = producers + consumers + if any(n.op_type in {"Cast", "Add", "Mul", "MatMul", "Conv"} for n in neighbors): + vi.type.tensor_type.elem_type = self.low_precision_type.onnx_typePlease confirm intent for 0-dim behavior (preserve vs. anonymize). Also consider adding a unit test covering:
- value_info for ArgMax/TopK indices stays INT64;
- tensors left UNDEFINED after first inference but belonging to float paths get defaulted to
low_precision_type
;- shapes with explicit 0 dims are preserved.
modelopt/torch/_deploy/utils/torch_onnx.py (2)
490-492
: Clarify assertion message for MXFP8/INT4 mixed precision. Message reads as BF16-specific; make the constraint explicit for all quantized cases here.
- assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet" + assert weights_dtype == "fp16", "Only FP16 weights are supported when the model is MXFP8 or INT4 quantized (BF16 unsupported)."
333-339
: Avoid mutable default for dynamic_axes and pass only when providedUsing
{}
as a default can lead to unintended shared state; also skip passing empty dict to export.- dynamic_axes: dict = {}, + dynamic_axes: dict | None = None, @@ - if not dynamo_export and Version(torch.__version__) >= Version("2.8"): - additional_kwargs["dynamic_axes"] = dynamic_axes + if not dynamo_export and Version(torch.__version__) >= Version("2.8") and dynamic_axes: + additional_kwargs["dynamic_axes"] = dynamic_axesAlso applies to: 435-437
examples/nemo_run/qat/nemo_qat_flow.py (2)
29-31
: Avoid sys.path mutation; prefer package/relative imports. Using sys.path.append can surprise downstream tools. If feasible, make examples a package and import via examples.nemo_run.common.utils or use relative imports.
149-156
: Separate recipe name from model naming to avoid confusion. model_name is overloaded (recipe vs HF model). The fallback branch is unreachable. Consider recipe_name = args.finetune_recipe for llm module lookup and model_dir = os.path.basename(args.model_name) for path naming (bf16_ckpt_path/ptq_model_out).
Also applies to: 173-184
examples/onnx_ptq/torch_quant_to_onnx.py (5)
86-92
: Avoid re-instantiating/downloading the model just to read input_sizeConstructing a model here duplicates work (and with pretrained=True may trigger weight downloads). Derive input_size from timm’s pretrained cfg, falling back to a lightweight model with pretrained=False.
Apply this diff:
def get_model_input_shape(model_name, batch_size): """Get the input shape from timm model configuration.""" - model = timm.create_model(model_name, pretrained=True, num_classes=1000) - data_config = timm.data.resolve_model_data_config(model) - input_size = data_config["input_size"] - return (batch_size, *tuple(input_size)) # Add batch dimension + # Prefer config path to avoid heavyweight instantiation / downloads. + try: + cfg = timm.get_pretrained_cfg(model_name) + input_size = tuple(getattr(cfg, "input_size", cfg.get("input_size"))) + except Exception: + # Fallback: create a lightweight model without pretrained weights. + model = timm.create_model(model_name, pretrained=False, num_classes=1000) + data_config = timm.data.resolve_model_data_config(model) + input_size = data_config["input_size"] + return (batch_size, *input_size) # Add batch dimension
122-127
: Validate --batch_size (> 0) to avoid runtime errors.

Negative/zero batch sizes will break DataLoader and export assumptions. Add an argparse validator.
Apply this diff:
- parser.add_argument( - "--batch_size", - type=int, - default=1, - help="Batch size for calibration and ONNX model export.", - ) + parser.add_argument( + "--batch_size", + type=positive_int, + default=1, + help="Batch size for calibration and ONNX model export (must be > 0).", + )Add this helper near the other imports:
def positive_int(v: str) -> int: iv = int(v) if iv <= 0: raise argparse.ArgumentTypeError("batch_size must be a positive integer") return iv
132-132
: Remove duplicate model construction; derive input_shape from the instantiated model.

You construct the model again at Line 136. Compute input_shape from that instance instead of creating one inside get_model_input_shape.
Apply this diff:
- # Get input shape from model config - input_shape = get_model_input_shape(args.timm_model_name, args.batch_size) - - # Create model and move to appropriate device + # Create model and move to appropriate device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device) + # Derive input shape from the instantiated model (avoid extra construction) + data_config = timm.data.resolve_model_data_config(model) + input_shape = (args.batch_size, *tuple(data_config["input_size"]))
55-67
: Don’t preload calibration samples onto GPU; move batches in the forward loop + no_grad/eval.

Preloading each sample to device inflates GPU memory and fights DataLoader workers. Keep tensors on CPU, use pin_memory, and move inside the loop with inference_mode. Also ensure eval() during calibration.
Apply this diff:
def load_calibration_data(model_name, data_size, batch_size, device): @@ - images = dataset["train"][:data_size]["image"] - calib_tensor = [transforms(img) for img in images] - calib_tensor = [t.to(device) for t in calib_tensor] + images = dataset["train"][:data_size]["image"] + calib_tensor = [transforms(img) for img in images] return torch.utils.data.DataLoader( - calib_tensor, batch_size=batch_size, shuffle=True, num_workers=4 + calib_tensor, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True ) @@ - def forward_loop(model): - for batch in data_loader: - model(batch) + def forward_loop(model): + model.eval() + with torch.inference_mode(): + for batch in data_loader: + if isinstance(batch, (list, tuple)): + batch = [b.to(next(model.parameters()).device, non_blocking=True) for b in batch] + else: + batch = batch.to(next(model.parameters()).device, non_blocking=True) + model(batch)Also applies to: 74-79
155-161
: Expose dynamic_axes in export_to_onnx and forward it to the ONNX exporter.

export_to_onnx (examples/onnx_ptq/download_example_onnx.py) builds a fixed-size dummy_input and calls get_onnx_bytes without any dynamic_axes, so the produced ONNX is fixed-batch.
- Change export_to_onnx signature to accept an optional dynamic_axes and forward it to get_onnx_bytes / torch.onnx.export (or update modelopt/torch/_deploy/utils/torch_onnx.py if needed).
- Update the call in examples/onnx_ptq/torch_quant_to_onnx.py to pass, e.g. dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}; a sketch follows below.
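A hedged sketch of the suggested signature change; the helper's current body is assumed to wrap `torch.onnx.export` with fixed `input`/`output` names, so only the new parameter and its forwarding are the point here:

```python
import torch

def export_to_onnx(model, input_shape, onnx_save_path, dynamic_axes=None):
    """Export `model` to ONNX, optionally marking axes (e.g. batch) as dynamic."""
    dummy_input = torch.randn(*input_shape)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_save_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,  # forwarded only when the caller provides it
    )

# Caller side (torch_quant_to_onnx.py), as suggested above:
# export_to_onnx(model, input_shape, out_path,
#                dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
```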
modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (2)
60-63
: Return bytes, not bytearray, to match type hint.

```diff
-def get_engine_bytes(engine: trt.tensorrt.ICudaEngine) -> bytes:
-    """Return serialized TensorRT engine bytes."""
-    return bytearray(engine.serialize())  # type: ignore[return-value]
+def get_engine_bytes(engine: trt.tensorrt.ICudaEngine) -> bytes:
+    """Return serialized TensorRT engine bytes."""
+    return bytes(engine.serialize())
```
134-141
: Minor: avoid parsing ONNX twice in calib_data_generator.

Parse once and reuse for input names and batch size.

```diff
-def calib_data_generator(onnx_bytes: bytes, input_tensors: list[np.ndarray]):
+def calib_data_generator(onnx_bytes: bytes, input_tensors: list[np.ndarray]):
     """The calibation data generator that yields calibration feed_dict to tensorrt."""
-    input_names = get_onnx_input_names(onnx.load_from_string(onnx_bytes))
-
-    batch_size = get_batch_size(onnx.load_from_string(onnx_bytes))
+    model = onnx.load_from_string(onnx_bytes)
+    input_names = get_onnx_input_names(model)
+    batch_size = get_batch_size(model)
```

modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py (1)
103-110
: STRONGLY_TYPED intentionally treated as low‑bit; mapping present — document rationale.

Verified: TRTMode.STRONGLY_TYPED is included in _is_low_bit_mode and TRT_MODE_FLAGS maps it to ["--stronglyTyped"] (modelopt/torch/_deploy/_runtime/tensorrt/constants.py); examples/README also use --builderOptimizationLevel=4 with --stronglyTyped, so forcing opt‑level=4 appears intentional.
- Optional: add a brief inline comment explaining why strongly‑typed forces builderOptimizationLevel=4 or gate this behavior if different opt‑levels are expected for fp16/bf16 strongly‑typed builds.
examples/windows/onnx_ptq/genai_llm/requirements.txt (1)
2-2
: Avoid ONNX version skew with setup.py (extras use ~=1.19.0).Pinning 1.18.0 here while extras require ~=1.19.0 can cause resolver conflicts in mixed flows (dev installs extras + runs this example). Either align, or clearly scope this pin to Windows-only pipelines.
Two safe options:
-onnx==1.18.0 +onnx>=1.18,<1.20 # Aligns with repo extras (~=1.19.x) while keeping 1.18 compatibilityor keep 1.18.0 but document intent:
+# Note: ONNX pinned to 1.18.0 for Windows PTQ sample stability; repo extras target ~=1.19.x. onnx==1.18.0
Please confirm which constraint you prefer and that CI matrices won’t install both this file and the onnx extra together.
tests/unit/onnx/test_qdq_rules_int8.py (1)
98-101
: Guard against graph inputs when inspecting Add inputs.This list comp will crash if an Add input is a graph input (no producers). Mirror the new guard here.
- add_input_ops = [inp.inputs[0].op for inp in add_node.inputs] - assert np.isin(add_input_ops, ["Conv", "DequantizeLinear"]).all(), ( + add_inputs_with_producer = [inp for inp in add_node.inputs if inp.inputs] + add_input_ops = [inp.inputs[0].op for inp in add_inputs_with_producer] + assert set(add_input_ops) <= {"Conv", "DequantizeLinear"} and len(add_input_ops) == 2, ( f"Add node {add_node.name} was not quantized correctly!" )Please confirm test models never wire raw graph inputs into Add in this check path; if they can, the len==2 assertion is the right expectation.
examples/speculative_decoding/launch.sh (1)
153-154
: Make reporting backend configurable (default remains tensorboard).Hard‑coding tensorboard is OK, but a flag improves ergonomics and parity with wandb.
- --report_to tensorboard \ + --report_to ${REPORT_TO:-tensorboard} \Add parser support:
@@ --do_eval*) if [[ "$1" != *=* ]]; then shift; fi DO_EVAL="${1#*=}" ;; + --report_to*) + if [[ "$1" != *=* ]]; then shift; fi + REPORT_TO="${1#*=}" + ;;modelopt/torch/export/plugins/__init__.py (1)
23-24
: Avoid hard dependency via wildcard import; degrade gracefully.Unconditional
from .hf_spec_export import *
can pull in heavy/optional deps and pollute the namespace. Prefer a guarded import.-from .hf_spec_export import * +try: + from .hf_spec_export import * +except ImportError: + # hf_spec_export depends on optional stacks; keep plugins package importable without them + passIf hf_spec_export has no optional deps, we can still import explicitly‑named symbols instead of
*
to keep API surface tight. Want me to propose an explicit all?modelopt/onnx/quantization/int8.py (1)
127-129
: Defaulting to fp16 is a behavior change; confirm CPU/ORT compatibility.This will convert models to fp16/bf16 by default. Ensure:
- CPU‑only pipelines and DML EP handle the converted graph without perf/accuracy regressions.
- CLI/docs reflect the new default and how to opt out.
Optionally allow disabling via None:
- high_precision_dtype: str = "fp16", + high_precision_dtype: str | None = "fp16", @@ - if high_precision_dtype in ["fp16", "bf16"]: + if high_precision_dtype in ["fp16", "bf16"]: ...Please run a quick smoke on a CPU runner with ORT CPU EP to verify no op‑type falls back or dtype unsupported errors.
.github/workflows/unit_tests.yml (1)
122-127
: Aggregator job added — consider making it resilient to upstream skips/failures.Looks good for enforcing a single PR “unit tests complete” signal. To make the job always run and clearly fail when any prerequisite fails or is skipped, add job-level
if: always()
and an explicit failure check againstneeds.*.result
.- unit-pr-required-check: - if: github.event_name == 'pull_request' + unit-pr-required-check: + # Always evaluate this job for PRs so it can fail if any prereq failed/skipped + if: github.event_name == 'pull_request' && always() needs: [linux, windows, multi-py, multi-torch, multi-transformers, partial-install] runs-on: ubuntu-latest steps: - - run: echo "All PR unit test jobs completed" + - name: Verify prerequisite jobs + run: | + echo "linux: ${{ needs.linux.result }}" + echo "windows: ${{ needs.windows.result }}" + echo "multi-py: ${{ needs.multi-py.result }}" + echo "multi-torch: ${{ needs.multi-torch.result }}" + echo "multi-transformers: ${{ needs.multi-transformers.result }}" + echo "partial-install: ${{ needs['partial-install'].result }}" + if [[ "${{ needs.linux.result }}" != "success" || \ + "${{ needs.windows.result }}" != "success" || \ + "${{ needs.multi-py.result }}" != "success" || \ + "${{ needs.multi-torch.result }}" != "success" || \ + "${{ needs.multi-transformers.result }}" != "success" || \ + "${{ needs['partial-install'].result }}" != "success" ]]; then + echo "One or more unit test jobs did not succeed" + exit 1 + fi + - run: echo "All PR unit test jobs completed"tests/_test_utils/onnx_quantization/utils.py (1)
23-29
: Unwrap chains of Cast nodes, not just a single Cast.Current logic handles only one Cast. Robustly skip over multiple Casts to reach DQ.
- producer = node.i(inp_idx) - # Quantized path may include a Cast right after DQ - if producer and producer.op == "Cast": - producer = producer.i(0) + producer = node.i(inp_idx) + # Quantized path may include one or more Casts right after DQ + while producer and producer.op == "Cast": + producer = producer.i(0)modelopt/onnx/trt_utils.py (1)
419-424
: Good: avoid empty quantize mappings.

Creating
custom_ops_to_quantize[op_type]
only when IOs exist is correct and prevents misleading empty config. Consider mirroring this forcustom_ops_to_cast
(skip when both lists are empty) for symmetry and cleaner downstream handling.tests/_test_utils/import_helper.py (1)
80-93
: Fix skip message logic and align skip behavior.Message says “less than required” but you skip when ONNX is greater than 1.18. Also, add
allow_module_level=True
for consistency.-def skip_if_onnx_version_above_1_18(): +def skip_if_onnx_version_above_1_18(): package_name = "onnx" required_version = "1.18.0" try: installed_version = importlib.metadata.version(package_name) except importlib.metadata.PackageNotFoundError: - pytest.skip(f"{package_name} is not installed") + pytest.skip(f"{package_name} is not installed", allow_module_level=True) if version.parse(installed_version) > version.parse(required_version): - pytest.skip( - f"{package_name} version {installed_version} is less than required {required_version}" - ) + pytest.skip( + f"{package_name} version {installed_version} is greater than allowed {required_version}", + allow_module_level=True, + )tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (2)
43-44
: Place version gating at module import time to save setup cost.Optional: invoke the skip helper at module scope so the entire file is skipped before CUDA/model setup.
- def test_int4_awq(tmp_path): - skip_if_onnx_version_above_1_18() +skip_if_onnx_version_above_1_18() + +def test_int4_awq(tmp_path):
119-121
: Order of skips is fine; minor consistency nit.Calling ONNX version skip before libcudnn skip is OK. If you move the version skip to module scope, keep
skip_if_no_libcudnn()
first inside the test to preserve current behavior for environments without cuDNN.examples/speculative_decoding/ar_validate.py (2)
29-31
: Default sample count increased — confirm runtime/CI budget.

Bumping `num_samples`
to 80 increases runtime. If this is used in CI, consider parameterizing via env or keeping a lower default for CI.
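If CI budget is the concern, an environment override is one lightweight option; this is a sketch only, and the `AR_VALIDATE_NUM_SAMPLES` variable plus the existing argparse `parser` are assumptions rather than current code:

```python
import os

# Keep 80 as the local default, but let CI shrink the run without touching the CLI.
DEFAULT_NUM_SAMPLES = int(os.environ.get("AR_VALIDATE_NUM_SAMPLES", "80"))

parser.add_argument(
    "--num_samples",
    type=int,
    default=DEFAULT_NUM_SAMPLES,
    help="Number of samples to validate (override via AR_VALIDATE_NUM_SAMPLES).",
)
```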
58-63
: CLI defaults updated — keep docs/launch scripts in sync.Ensure any docs or scripts referencing
--osl
and--num_samples
defaults are updated..github/workflows/gpu_tests.yml (1)
85-92
: Required GPU check logic is sound; add a success echo for clarity.The conditional failure is correct. Add a final unconditional echo so the job shows a green step on success.
gpu-pr-required-check: # Run even if gpu-tests-pr is skipped if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }} needs: [check-file-changes, gpu-tests-pr] runs-on: ubuntu-latest steps: - name: Required GPU tests did not succeed if: ${{ needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && needs.gpu-tests-pr.result != 'success') }} run: exit 1 + - run: echo "GPU test requirements satisfied"
examples/speculative_decoding/server_generate.py (1)
155-158
: Avoid NameError when printing prompt in exceptions.
prompt
is undefined in the chat path; the except handler can raise another error.Apply this diff:
- except Exception as e: - print(e) - print(prompt) - print("Failed to generate data") + except Exception as e: + print(e) + if "prompt" in locals(): + print(prompt) + print("Failed to generate data")examples/speculative_decoding/calibrate_draft_vocab.py (1)
31-35
: Validate --draft_vocab_size and clarify help.Add basic bounds checks to fail fast and document expectations.
Apply this diff:
- "--draft_vocab_size", - type=int, - required=True, - help="Draft vocab size", + "--draft_vocab_size", + type=int, + required=True, + help="Draft vocab size (must be > 0 and <= tokenizer vocab size)",And add after parsing (right after Line 45):
args = parser.parse_args() +if args.draft_vocab_size <= 0: + raise ValueError("--draft_vocab_size must be > 0")Optionally check against tokenizer size after loading it:
tokenizer = AutoTokenizer.from_pretrained(args.model) +if hasattr(tokenizer, "vocab") and args.draft_vocab_size > len(tokenizer.vocab): + raise ValueError(f"--draft_vocab_size ({args.draft_vocab_size}) exceeds tokenizer size ({len(tokenizer.vocab)})")modelopt/torch/export/plugins/hf_spec_export.py (4)
43-49
: Report all missing keys at once for better diagnostics.Collect and display the full set of missing required keys; current code raises on the first one.
def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict): """Check if the state dict keys match.""" draft_keys = set(draft_model.state_dict().keys()) - for required_key in required_items: - if required_key not in draft_keys: - raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}") + missing = [k for k in required_items if k not in draft_keys] + if missing: + raise ValueError( + "State dict keys mismatch! Missing in draft model: " + + ", ".join(sorted(missing)) + )
63-75
: Guard against missing eagle_module attribute.If
_modelopt_state
indicateseagle
butmodel.eagle_module
is absent, this will raiseAttributeError
.Apply this diff:
- _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"]) + if not hasattr(model, "eagle_module"): + raise ValueError("Eagle mode detected but model.eagle_module is missing") + _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
76-79
: Validate LM head fallback shape to avoid silent mismatch.When
eagle_lm_head.weight
is missing, copying baselm_head.weight
may mis-shape the draft head (e.g., draft vs base vocab sizes). Validate and fail fast.- if "eagle_lm_head.weight" not in eagle_state: - export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"] + if "eagle_lm_head.weight" not in eagle_state: + base_lm = model.state_dict().get("lm_head.weight") + if base_lm is None: + raise ValueError("lm_head.weight not found in base model for fallback") + # Optional: if d2t present, ensure vocab alignment + d2t = export_state_dict.get("d2t") + if d2t is not None and base_lm.shape[0] != d2t.numel(): + raise ValueError( + f"LM head vocab size ({base_lm.shape[0]}) does not match draft vocab size ({d2t.numel()})" + ) + export_state_dict["lm_head.weight"] = base_lm
94-131
: Set transformers_version automatically when missing.Helps downstream tooling that expects this field.
def set_config_if_spec_decoding(model: nn.Module, config_data: dict): @@ - template_config = { + template_config = { @@ "transformers_version": None, @@ } @@ - for key in template_config: + for key in template_config: value = template_config[key] @@ template_config[key] = new_value + # Populate transformers_version if available + if template_config.get("transformers_version") is None: + try: + import transformers + template_config["transformers_version"] = transformers.__version__ + except Exception: + passmodelopt/onnx/quantization/qdq_utils.py (3)
620-635
: Verify FP4 packing axis; consider packing along the last dimension for consistency.Current packing flattens and halves the first dimension, whereas NVFP4 tensor packing typically pairs along the last axis. If consumer kernels expect last-axis packing, this will mis-shape weights.
Alternative packing along the last axis:
- array_f32_t_shape = array_f32_t.shape - assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2" - array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:]) + shape = list(array_f32_t.shape) + assert shape[-1] % 2 == 0, "last dimension must be divisible by 2 for FP4 packing" + packed_shape = [*shape[:-1], shape[-1] // 2] @@ - array_f4_t = array_f4_t.flatten() - array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape) + array_f4_t_packed = (array_f4_t[..., 0::2] | (array_f4_t[..., 1::2] << 4)).contiguous() + array_f4_t_packed = array_f4_t_packed.reshape(packed_shape)If first-axis packing is intentional for downstream ORT/TRT consumption, please confirm and ignore this suggestion.
943-945
: Don’t rely on “Constant” substring in tensor names; check producer op_type.Name-based matching is brittle. Use the producer map to find and remove the Constant node.
Apply this diff:
- # Remove constant node from reshape node - shape_constant_name = next(input for input in reshape_node.input if "Constant" in input) - nodes_to_remove.append(tensor_producer_map[shape_constant_name].name) + # Remove Constant node feeding the Reshape shape input + shape_input = next(inp for inp in reshape_node.input if inp != node.output[0]) + shape_producer = tensor_producer_map.get(shape_input) + if shape_producer and shape_producer.op_type == "Constant": + nodes_to_remove.append(shape_producer.name)
869-869
: Avoid forcing ir_version downgrade unless strictly required.

Setting `onnx_model.ir_version = 10`
can regress compatibility/features. Prefer preserving original IR or conditionally lowering only when exporters/ORT demand it.If this is required for current ORT targets, please document the constraint and link the issue.
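A hedged sketch of the conditional lowering, where the assumption is that IR version 10 is the documented ceiling for the targeted onnxruntime build:

```python
# Only lower the IR version when the model actually exceeds what the target
# onnxruntime is known to accept; otherwise leave the exporter's value alone.
MAX_ORT_SUPPORTED_IR_VERSION = 10  # assumption: documented ORT constraint

if onnx_model.ir_version > MAX_ORT_SUPPORTED_IR_VERSION:
    onnx_model.ir_version = MAX_ORT_SUPPORTED_IR_VERSION
```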
tests/unit/onnx/test_qdq_utils.py (1)
70-77
: Redundant Cast to FLOAT.
DequantizeLinear
already produces FLOAT. Casting to FLOAT again is harmless but unnecessary noise for the unit graph. Consider casting to FLOAT16 if the aim is to exercise downcast logic, or remove this node to keep the graph minimal.modelopt/torch/speculative/plugins/transformers.py (2)
817-819
: Guard DynamicCache conversion when cache is None.
DynamicCache.from_legacy_cache(None)
may not be supported across all HF versions. Add a None check.- if not isinstance(past_key_values, Cache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if past_key_values is not None and not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values)- if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) + if eagle_cache is not None and not isinstance(eagle_cache, Cache): + eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)Also applies to: 855-857
720-725
: Make draft‑vocab mapping device‑agnostic.Index tensor should live on the same device as
full_logits
.- reverse_mapping = ( - torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device) - + self.eagle_module.d2t - ) - return full_logits[:, :, reverse_mapping] + device = full_logits.device + d2t = self.eagle_module.d2t.to(device) + reverse_mapping = torch.arange(d2t.numel(), device=device, dtype=d2t.dtype) + d2t + return full_logits.index_select(dim=2, index=reverse_mapping)examples/vlm_ptq/README.md (1)
39-45
: Fix NVFP4 support inconsistency for Qwen2.5‑VL.Support matrix marks NVFP4 as unsupported (❌), but the HF example advertises
--quant nvfp4
. Please align one of them.If unsupported, apply:
- scripts/huggingface_example.sh --type qwen --model Qwen2.5-VL-7B-Instruct --export_fmt hf --quant [fp8|nvfp4|int8_sq|int4_awq|w4a8_awq] + scripts/huggingface_example.sh --type qwen --model Qwen2.5-VL-7B-Instruct --export_fmt hf --quant [fp8|int8_sq|int4_awq|w4a8_awq]Also applies to: 80-85
examples/speculative_decoding/main.py (3)
211-224
: Callback uses processing_class key; may not exist in HF callbacks.HF usually passes tokenizer, not processing_class; this can KeyError. Prefer a safe fallback.
Apply:
- ars = validate_ar( - model=kwargs["model"], - tokenizer=kwargs["processing_class"], + ars = validate_ar( + model=kwargs["model"], + tokenizer=kwargs.get("tokenizer") or kwargs.get("processing_class"), ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], device=kwargs["model"].device, )Optionally cache the dataset in init to avoid re-loading each interval.
211-214
: Default mismatch: ARValidationCallback(500) vs TrainingArguments default (1000).Not harmful since you pass the value, but confusing; align defaults.
- def __init__(self, ar_validate_steps: int = 500): + def __init__(self, ar_validate_steps: int = 1000):
236-236
: Avoid using Trainer._move_model_to_device (private API).Trainer handles device placement; calling a private method can break under FSDP/DeepSpeed.
- trainer._move_model_to_device(model, trainer.args.device) + # Rely on Trainer to handle device placementmodelopt/onnx/quantization/__main__.py (1)
181-189
: Behavior change: default high_precision_dtype now fp16 (was mode-dependent).This is a user-visible default change (e.g., INT8 used to keep fp32). Confirm docs/changelog call this out and consider emitting a runtime INFO when quantize_mode == "int8" and user didn't override.
Possible guard:
if args.quantize_mode == "int8" and not any(a.startswith("--high_precision_dtype") for a in sys.argv): print("INFO: defaulting high_precision_dtype=fp16; set --high_precision_dtype=fp32 to keep previous behavior.")tests/examples/speculative_decoding/test_eagle.py (1)
37-51
: Skip gracefully when no GPU to reduce CI flakiness.-def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path): +import pytest +def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path): + if num_gpus < 1: + pytest.skip("No GPU available")modelopt/onnx/quantization/quantize.py (1)
289-296
: Docstring tweak: clarify activations vs. weights conversion.Minor clarity: “weights and activations” not “weight and activation”.
- and the input model is of dtype fp32, model's weight and activation will be converted to - 'fp16' or 'bf16'. + and the input model is fp32, the model's weights and activations will be converted to + 'fp16' or 'bf16'.examples/vlm_ptq/scripts/huggingface_example.sh (1)
94-99
: Batch size 20 for qwen may OOM on smaller GPUs.Consider making this env/flag tunable or auto-scaling by GPU memory.
: "${BUILD_MAX_BATCH_SIZE:=$([ "$MODEL_TYPE" = "llava" ] || [ "$MODEL_TYPE" = "vila" ] || [ "$MODEL_TYPE" = "qwen" ] && echo 20 || echo 4)}"modelopt/torch/export/unified_export_hf.py (1)
513-519
: Gate saving hf_quant_config.json when no quantization is applied.Writing hf_quant_config.json even for QUANTIZATION_NONE creates confusing artifacts. Suggest saving only when a quant scheme is present.
Apply this diff:
- # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied? - # Save hf_quant_config.json for backward compatibility - with open(f"{export_dir}/hf_quant_config.json", "w") as file: - json.dump(hf_quant_config, file, indent=4) + # Save hf_quant_config.json only if any quantization is applied (backward compatibility) + quant = hf_quant_config.get("quantization", {}) + if quant.get("quant_algo") or quant.get("kv_cache_quant_algo") != QUANTIZATION_NONE: + with open(f"{export_dir}/hf_quant_config.json", "w") as file: + json.dump(hf_quant_config, file, indent=4)examples/onnx_ptq/evaluate.py (2)
52-58
: CLI arg rename looks good; minor help text nit."all other modes have been deprecated in TensorRT" is broad; if you keep it, consider “deprecated here” to avoid implying upstream deprecations. No functional change needed.
85-87
: Use deterministic evaluation dataloader.Shuffling eval data isn’t typical and makes runs non‑reproducible.
Apply this diff:
- val_loader = torch.utils.data.DataLoader( - val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4 - ) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4 + )tests/examples/test_onnx_ptq.sh (1)
164-173
: Fix array matching for latency-only models (ShellCheck SC2199/SC2076).
[[ " ${arr[@]} " =~ " $item " ]]
is brittle. Use a loop to test membership.Apply this diff:
- if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then + is_latency_model=false + for lm in "${latency_models[@]}"; do + if [[ "$lm" == "$model_name" ]]; then + is_latency_model=true + break + fi + done + if $is_latency_model; thenexamples/speculative_decoding/train_eagle3_and_export.sh (1)
49-55
: Guard against requesting more GPUs than available.If NUM_GPU exceeds available devices, CUDA_VISIBLE_DEVICES will reference non-existent IDs.
Apply this diff:
-if [[ "$NUM_GPU" == 1 ]]; then +avail="$(nvidia-smi --query-gpu=count --format=csv,noheader,nounits 2>/dev/null | head -n1)" +avail="${avail:-0}" +if [[ "$NUM_GPU" -gt "$avail" ]]; then + echo "Requested NUM_GPU=$NUM_GPU exceeds available GPUs ($avail)"; exit 1 +fi +if [[ "$NUM_GPU" == 1 ]]; then export CUDA_VISIBLE_DEVICES=0 else # Export as 0,1,...,N-1 for NUM_GPU GPUs devs="$(seq -s, 0 $((NUM_GPU-1)))" export CUDA_VISIBLE_DEVICES="$devs" fiexamples/speculative_decoding/launch_train.sh (3)
77-80
: Improve error message for invalid args.Print the actual flag that was invalid, not the post-‘=’ substring.
Apply this diff:
- >&2 printf "Error: Invalid argument ${1#*=}\n" + >&2 printf "Error: Invalid argument: %s\n" "$1"
88-91
: Avoid divide-by-zero when no GPUs are visible.GPU_COUNT can be 0 on CPU-only envs; protect DEFAULT_SAVE_STEPS.
Apply this diff:
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT)) +GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") +GPU_COUNT=$(( GPU_COUNT > 0 ? GPU_COUNT : 1 )) +DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
112-116
: Validate EAGLE config path when provided.Early fail produces clearer errors.
Apply this diff:
if [[ -n "$EAGLE_CONFIG" ]]; then - SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" + if [[ ! -f "$EAGLE_CONFIG" ]]; then + echo "eagle_config not found: $EAGLE_CONFIG"; exit 1 + fi + SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py (5)
35-39
: Beware: mutating shared default configs across tests.ALGO_TO_CONFIG points to module-level defaults; the code below mutates nested fields, risking cross‑test leakage. Clone before mutation.
Apply this diff:
- mtsp_config = ALGO_TO_CONFIG[algo] + from copy import deepcopy + mtsp_config = deepcopy(ALGO_TO_CONFIG[algo])If prior guidance asserts these are already deepcopied per access, please confirm; otherwise prefer defensive deepcopy here.
88-89
: Tighten error message formatting.Use f-string or remove braces for clarity.
Apply this diff:
- raise ValueError("Only algo={eagle1, eagle3, medusa} are supported!") + raise ValueError("Only algo in {eagle1, eagle3, medusa} are supported!")
123-124
: Use integer division for shape checks.Avoid float comparison by using //.
Apply this diff:
- assert logits.shape[2] == vocab_size / size + assert logits.shape[2] == vocab_size // size
95-101
: Typo in comment (non-functional).“extrat” → “extract”.
Apply this diff:
- # Eagle3 last layer has a forward hook to extrat the pre_norm hidden_state + # Eagle3 last layer has a forward hook to extract the pre_norm hidden_state
159-165
: Dead code: 'algo == "eagle"' branch never hit.Parametrization doesn’t include "eagle" anymore. Remove or update the skip logic.
- if algo == "eagle": - try: - import megatron.core.post_training # noqa: F401 - except ImportError: - pytest.skip("megatron.core.post_training not found") + # If specific dependencies are required for eagle variants, add checks for "eagle1"/"eagle3" here if needed.examples/speculative_decoding/README.md (14)
5-10
: Define α and γ once, and consider ASCII fallbacks.
A one‑liner clarifying α=accepted tokens per step and γ=draft length avoids ambiguity in downstream sections and helps readers who can’t render Unicode.Apply this diff:
-Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving generation speed. +Speculative decoding accelerates auto‑regressive generation by using a lightweight draft model to predict the next γ (gamma) tokens. The main LLM verifies these candidates in one forward pass. If the draft model is correct for α (alpha) tokens, the LLM accepts and generates α+1 tokens per step, improving throughput.
15-22
: Table wording polish and consistency.
- “Pre‑Requisites” → “Prerequisites”.
- Capitalization: “EAGLE model” → “EAGLE model” is fine; keep section names in Title Case.
Apply this diff:
-| Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] | +| Prerequisites | Required & optional dependencies | \[[Link](#prerequisites)\] | -| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] | +| Simplified Workflow | Train, evaluate, and export an EAGLE model with a one‑line command | \[[Link](#getting-started-simplified-workflow)\] |And change the section header at Line 26 accordingly:
-## Pre-Requisites +## Prerequisites
28-39
: Installation placeholders and dataset fetch details need to be actionable.
- Replace
pip install -e ...
with the actual path/package, or show both editable‑install and wheel options.- Cloning HF datasets via git requires git‑lfs; provide an alternative using datasets.load_dataset to avoid LFS issues.
Apply this diff:
-Install Modelopt with `hf` dependencies and other requirements for this example: +Install ModelOpt with Hugging Face dependencies and other requirements for this example: ```bash -pip install -e ... +pip install -e .[hf] pip install -r requirements.txt-We use Daring-Anteater dataset in this example. Download by:
+We use the Daring‑Anteater dataset. You can either clone with git‑lfs:-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater +git lfs install +git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
+Or programmatically download with the datasets library inside your script:
+
+python +from datasets import load_dataset +ds = load_dataset("nvidia/Daring-Anteater", split="train") +
--- `93-99`: **Link target inconsistency for default configs.** Earlier you link to `eagle/default_config.py#L18`; here you link to `speculative/config.py#L37`. Use one stable link (without line anchors) to avoid rot. Apply this diff: ```diff -For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. +For EAGLE‑1 and EAGLE‑3 we provide a default model architecture config in ModelOpt ([default_config.py](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py)).
106-111
: Hugging Face model load snippet: add tokenizer and dtype/device hints.
Minimal, but readers often copy‑paste. Adding tokenizer load and torch dtype makes it runnable.Apply this diff:
-model = transformers.AutoModelForCausalLM.from_pretrained( - "<path to your pretrained model>" -) +tokenizer = transformers.AutoTokenizer.from_pretrained("<base model or path>", use_fast=True) +model = transformers.AutoModelForCausalLM.from_pretrained( + "<base model or path>", torch_dtype="auto", device_map="auto" +)
136-138
: Show imports formtsp.convert
.
Without imports,mtsp
is undefined. Add a one‑liner import.Apply this diff:
-```python -mtsp.convert(model, [("eagle", config)]) -``` +```python +from modelopt.torch import speculative as mtsp +mtsp.convert(model, [("eagle", config)]) +```
157-164
: Script invocation LGTM, add line-break alignment nit.
The multi‑line backslash formatting is readable; keep comments spaced by two spaces after#
.
170-175
: Validation command LGTM.
Clear and minimal. Consider noting expected output (acceptance rate summary) for quick sanity checks.
180-185
: Export step LGTM; add note on target format.
Specify whether the export produces a Hugging Face‑compatible directory, a safetensors file, or TRT‑LLM artifacts.
229-236
: Support matrix naming consistency and scope disclaimer.
- Normalize model names: “Llama 2”, “Llama 3/3.1”, “Qwen 1.5/2/2.5”.
- Add a note that support depends on upstream ecosystem versions and is subject to change.
Apply this diff:
-| LLAMA 2 | ✅ | ✅ | ✅ | -| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | +| Llama 2 | ✅ | ✅ | ✅ | +| Llama 3, 3.1 | ✅ | ✅ | ✅ | ... -| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | +| Qwen 1.5/2/2.5 | ✅ | ✅ | ✅ |And append under the table:
+Note: Support may vary by framework/runtime versions and will evolve over time. Refer to the release notes for the most up‑to‑date matrix.
239-241
: Checkpoint collection link LGTM.
Consider adding a note about licenses and usage terms for individual checkpoints.
64-70
: Update vLLM example: keep --api-key, use env var, and correct quantization note
- vLLM's
serve
exposes--api-key
— do not remove it. Replace the hardcoded token with an env-var pattern and show passing it to the CLI, e.g.:
export VLLM_API_KEY="token-abc123"
vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key "$VLLM_API_KEY" --port 8000 --tensor-parallel-size 1- Correct the quantization note: vLLM supports
--quantization
(includingmodelopt
) and often auto-detects quantized models viahf_quant_config.json
; update wording to reflect this and advise verifying the exact flag for the vLLM version used.File: examples/speculative_decoding/README.md (lines 64–70)
142-153
: Avoid private Trainer API; import ModelOpt checkpoint helper from modelopt.torch.opt.
- Remove trainer._move_model_to_device(...) — it's a private Trainer helper; replace with model.to(trainer.args.device) or rely on Trainer's device-placement (TrainingArguments) so examples don't call a private API. (huggingface.co)
- The HF-checkpoint helper is modelopt.torch.opt.enable_huggingface_checkpointing — use either
import modelopt.torch.opt as mto; mto.enable_huggingface_checkpointing()
orfrom modelopt.torch.opt import enable_huggingface_checkpointing; enable_huggingface_checkpointing()
. Do not use modelopt.torch.export for this function. (nvidia.github.io)- Location: examples/speculative_decoding/README.md (lines 142–153). (github.com)
41-53
: Simplified workflow example is valid — optional clarity tweak recommended.
train_eagle3_and_export.sh parses --data (default Daring-Anteater/train.jsonl), --num_gpu, and auto-creates OUTPUT_DIR, so the README one-liner works as-is; optionally add explicit --data and --output_dir in examples/speculative_decoding/README.md to make inputs/outputs clearer.
"rope_scaling": { | ||
"factor": 32.0, | ||
"low_freq_factor": 1.0, | ||
"high_freq_factor": 4.0, | ||
"original_max_position_embeddings": 8192, | ||
"rope_type": "llama3" | ||
}, | ||
"initializer_range": 0.02 |
🧩 Analysis chain
Validate removal of draft_vocab_size and new rope_scaling fields usage.
If any code still reads draft_vocab_size
, this change could break speculative decoding config parsing. Please verify references and update defaults.
[suggest_minor_issue]
Run to check for references:
🏁 Script executed:
#!/bin/bash
fd -HI --strip-cwd-prefix -t f '.*' | xargs rg -nP 'draft_vocab_size\b' || true
Length of output: 5961
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "Searching for rope_scaling and related keys..."
fd -HI --strip-cwd-prefix -t f '.*' | xargs rg -nP 'rope_scaling\b|rope_type\b|low_freq_factor\b|high_freq_factor\b|original_max_position_embeddings\b' || true
echo
echo "Searching for eagle_architecture_config usages (context +-2 lines)..."
fd -HI --strip-cwd-prefix -t f '.*' | xargs rg -nP -n -C2 'eagle_architecture_config\b' || true
Length of output: 18106
Restore draft_vocab_size or update all references and defaults
draft_vocab_size is still referenced across the codebase and tests — removing it from configs will break parsing/initialization. Update or restore the key and align defaults/exports/tests.
Key locations to change:
- modelopt/torch/speculative/eagle/default_config.py
- modelopt/torch/speculative/plugins/transformers.py
- modelopt/torch/speculative/plugins/megatron_eagle.py
- modelopt/torch/speculative/config.py
- modelopt/torch/export/unified_export_megatron.py
- modelopt/torch/export/plugins/hf_spec_export.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/calibrate_draft_vocab.py
- tests/unit/torch/speculative/plugins/test_hf_speculative.py
- tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py
Fix options: (A) keep "draft_vocab_size" in example eagle_config.json/defaults (recommended), or (B) remove it everywhere and update parsing, exporters, and tests to use the new rope_scaling fields.
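For option (A), a hedged sketch of keeping the key tolerable on the loading side; the file name and the 32000 default (matching the --draft_vocab_size used in the calibration example later in this thread) are illustrative:

```python
import json

with open("eagle_config.json") as f:  # hypothetical example config path
    eagle_config = json.load(f)

# Keep draft_vocab_size available to downstream parsing/exporters even when a
# config omits it, instead of letting initialization fail on the missing key.
eagle_config.setdefault("draft_vocab_size", 32000)
```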
```python
def parse_args():
    parser = argparse.ArgumentParser(
        description="Export a HF checkpoint (with ModelOpt state) for deployment."
    )
    parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
    parser.add_argument(
        "--export_path", type=str, default="Destination directory for exported files."
    )
    return parser.parse_args()
```
Fix argparse defaults and missing help; mark args required.
Current defaults are descriptive strings, leading to load failures when args aren’t passed.
Apply this diff:
```diff
-def parse_args():
+def parse_args():
     parser = argparse.ArgumentParser(
         description="Export a HF checkpoint (with ModelOpt state) for deployment."
     )
-    parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
-    parser.add_argument(
-        "--export_path", type=str, default="Destination directory for exported files."
-    )
+    parser.add_argument("--model_path", type=str, required=True,
+                        help="Path to the trained checkpoint directory.")
+    parser.add_argument("--export_path", type=str, required=True,
+                        help="Destination directory for exported files.")
     return parser.parse_args()
```
📝 Committable suggestion

```python
def parse_args():
    parser = argparse.ArgumentParser(
        description="Export a HF checkpoint (with ModelOpt state) for deployment."
    )
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the trained checkpoint directory.")
    parser.add_argument("--export_path", type=str, required=True,
                        help="Destination directory for exported files.")
    return parser.parse_args()
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/export_hf_checkpoint.py around lines 27 to 35,
the argparse arguments use descriptive strings as defaults which causes load
failures and lack help text; change each parser.add_argument to remove the
descriptive default, set required=True, and provide a proper help="..." string
(e.g., help="Path of the trained checkpoint." for --model_path and
help="Destination directory for exported files." for --export_path) so the args
are mandatory and have correct help guidance.
```python
try:
    import wandb

    wandb.init()
except ImportError:
    wandb = None
```
Don't init wandb at import time; broaden error handling.
Module-level init causes side effects, multi-process duplication, and uncaught init errors (non-ImportError). Move init into train(), gate on main process, and tolerate failures.
Apply this diff here and add in-train init per snippet below:
```diff
-try:
-    import wandb
-
-    wandb.init()
-except ImportError:
-    wandb = None
+wandb = None
```
Add inside train() after parsing args (use global to keep callback check working):
```python
# inside train(), after args parsing
global wandb
if os.environ.get("WANDB_DISABLED", "").lower() not in ("1", "true", "yes"):
    try:
        import wandb as _wandb

        if getattr(training_args, "local_rank", -1) in (-1, 0):
            _wandb.init()  # optionally project/name from args/env
        wandb = _wandb
    except Exception as e:
        print_rank_0(f"wandb disabled: {e}")
        wandb = None
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/main.py around lines 50-55, the module
currently imports and calls wandb.init() at import time and only catches
ImportError; move the side-effectful initialization into train() (after args
parsing) and replace the module-level init with a simple import-or-None
assignment (wandb = None or import without init). Inside train(), use global
wandb, check WANDB_DISABLED env var, attempt to import wandb as a temporary name
(e.g., _wandb), and only call _wandb.init() when on the main process
(training_args.local_rank in (-1,0)); catch broad exceptions during import/init,
print a rank-0 message on failure, and set wandb = None so training proceeds
without crashing.
```python
if wandb:
    wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
```
Gate wandb.log to main process to avoid duplicate logs.
All ranks invoke callbacks; without gating you'll spam W&B.
```diff
-        if wandb:
-            wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+        if wandb and getattr(args, "local_rank", -1) in (-1, 0):
+            wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
```
📝 Committable suggestion

```python
if wandb and getattr(args, "local_rank", -1) in (-1, 0):
    wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/main.py around lines 225-226, wandb.log is
called from every rank which causes duplicate logs; gate the call so only the
main process logs (e.g., wrap wandb.log with a check for main process such as if
accelerator.is_main_process: or if dist.get_rank() == 0 / rank == 0), ensuring
that the main-process check is available/imported from your launcher
(Accelerator or torch.distributed) and leave all other ranks skipping the
wandb.log call.
We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:

```bash
python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
```

`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding` which contains `server_generate.py` and `distributed_generate`.
This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10\*max_lines_per_shard number of samples).
In this case, the first 40 shards of data will be processed.
To process the next 40 shards
This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
|
d2t mapping formula is likely incorrect.
A draft→target token mapping generally isn’t an additive offset; it’s an index map. The expression `target_token = draft_token + d2t[draft_token]` risks invalid IDs. Prefer direct lookup `target_token = d2t[draft_token]`.
Apply this diff:
-This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
+This produces a `d2t.pt` file in `save_dir`, which maps draft token IDs to target token IDs. During inference, map with `target_token = d2t[draft_token]`.
🤖 Prompt for AI Agents
In examples/speculative_decoding/README.md around lines 83–90, the documented
draft→target mapping formula is incorrect; replace the additive formula with a
direct index lookup. Update the text to say that d2t is an index map and during
inference you obtain the target token by looking up the draft token in d2t
(target token = the value at d2t[draft_token]) rather than adding an offset, and
ensure the example command and resulting file description remain unchanged.
"--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate" | ||
) | ||
parser.add_argument("--chat", action="store_true", help="Use chat mode") | ||
parser.add_argument("--chat", default=True, type=bool, help="Use chat mode") |
Fix argparse boolean parsing for --chat (current code always evaluates to True).
`type=bool` on argparse treats any non-empty string as True, so `--chat False` still enables chat. Use BooleanOptionalAction (Python ≥3.9) or paired flags.
Apply this diff:
-parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
+parser.add_argument(
+ "--chat",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Use chat mode",
+)
If older Python needs support, I can wire a `str2bool` helper or add `--no-chat`.
📝 Committable suggestion

```python
parser.add_argument(
    "--chat",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Use chat mode",
)
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/server_generate.py around line 49, the argparse
argument parser uses type=bool which makes any non-empty string evaluate True so
"--chat False" still enables chat; replace this with a proper boolean flag:
either use argparse.BooleanOptionalAction (Python ≥3.9) by changing the
add_argument call to use action=argparse.BooleanOptionalAction and remove
type/default pairing, or implement a str2bool helper and set type=str2bool with
default=True, or add paired flags (--chat / --no-chat) to correctly parse false
values; update help text accordingly.
```python
use_torch_autocast = not (
    is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
)
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
```
|
Include INT4 in autocast gating to avoid unintended mixed precision during export
Autocast remains enabled for INT4-quantized models, which can inject casts during tracing/export and skew ONNX dtypes. Align with FP4/MXFP8 behavior by disabling autocast for INT4 as well.
Apply:
- use_torch_autocast = not (
- is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
- )
+ use_torch_autocast = not (
+ is_fp4_quantized(model)
+ or is_mxfp8_quantized(model)
+ or is_int4_quantized(model)
+ or weights_dtype == "fp32"
+ )
Optional guard to prevent crashes on CPU-only hosts:
- autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
+ autocast = (
+ torch.autocast("cuda") if use_torch_autocast and torch.cuda.is_available() else nullcontext()
+ )
📝 Committable suggestion

```python
use_torch_autocast = not (
    is_fp4_quantized(model)
    or is_mxfp8_quantized(model)
    or is_int4_quantized(model)
    or weights_dtype == "fp32"
)
autocast = (
    torch.autocast("cuda") if use_torch_autocast and torch.cuda.is_available() else nullcontext()
)
```
🤖 Prompt for AI Agents
In modelopt/torch/_deploy/utils/torch_onnx.py around lines 396 to 400, the
autocast gating does not include INT4 quantized models so autocast can remain
enabled for INT4 and introduce unwanted casts during export; update the
use_torch_autocast condition to also check for INT4 (e.g., add
is_int4_quantized(model) or equivalent) so autocast is disabled for INT4 like
FP4/MXFP8, and when creating autocast only use torch.autocast("cuda") if CUDA is
available (torch.cuda.is_available()) otherwise fall back to nullcontext() to
avoid crashes on CPU-only hosts.
```
05_train-->07_export_hf;
```

## Results
@omrialmog I added QAT results here
Actionable comments posted: 3
♻️ Duplicate comments (2)
examples/nemo_run/qat/README.md (2)
19-29
: Flow stage names now match exact task IDs — LGTM.Stages align with logs/ADVANCED.md naming (e.g., 00_openscience_data).
51-52
: Move “Usage” right after “Overview.”Previous feedback requested this; it’s still after “Results.” Reordering improves UX.
🧹 Nitpick comments (5)
examples/nemo_run/qat/README.md (5)
5-7
: Consolidate duplicate ADVANCED.md links.Both “Slurm Examples” and “Advanced Topics” point to the same ADVANCED.md. Collapse into one, or point them to distinct anchors.
Apply this diff:
-[Slurm Examples](ADVANCED.md) | -[Advanced Topics](ADVANCED.md) | +[Advanced Topics (incl. Slurm)](ADVANCED.md) |
41-50
: Add reproducibility metadata to results.Include seed, NeMo container tag, NeMo commit, TE/version, and exact eval subset definition (which “5%” slice, shuffled/fixed). Improves comparability of MMLU numbers.
55-55
: Nit: extra space and wording.“…either locally or…” → single space; also prefer “on a Slurm cluster (see ADVANCED.md)”.
Apply this diff:
-You can run the example either locally or on a [Slurm cluster](ADVANCED.md). +You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
76-82
: Make the working directory explicit in the example.Reduce ambiguity by including a cd.
Apply this diff:
-```bash -python qat/nemo_qat_flow.py --log-dir /my/log/dir --experiment qat_experiment -``` +```bash +cd examples/nemo_run +python qat/nemo_qat_flow.py --log-dir /my/log/dir --experiment qat_experiment +```
88-94
: Mirror the explicit cd in QAD example.Consistency helps copy‑paste.
Apply this diff:
-```bash -python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment -``` +```bash +cd examples/nemo_run +python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment +```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
examples/nemo_run/qat/README.md
(2 hunks)modelopt/torch/export/plugins/nemo_run.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/export/plugins/nemo_run.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
examples/nemo_run/qat/README.md (3)
30-39
: Mermaid graph matches the 8-stage flow — LGTM.Edges reflect the described execution order.
98-104
: Defaults and HW notes — LGTM.Clear default (Qwen3-8B/qwen3_8b) and GPU/node requirements.
107-108
: Confirmed: CLI flags exist as documented.Found: --tensor_parallelism and --pipeline_parallelism (examples/nemo_run/qat/nemo_qat_flow.py; examples/nemo_run/common/in_memory_mmlu.py); --ptq-gpus, --train-gpus, --train-nodes (examples/nemo_run/qat/nemo_qat_flow.py).
Actionable comments posted: 2
♻️ Duplicate comments (4)
examples/nemo_run/qat/nemo_qat_flow.py (3)
138-145
: Fix KV‑cache tri‑state; don’t force disable by default.Current CLI always passes --disable_kv_cache when the user doesn’t opt in. Make it an explicit tri‑state and build the flag only if set. (Matches prior feedback.)
@@ - parser.add_argument( - "--enable_kv_cache", - help="Enables KV-cache quantization", - action="store_true", - default=False, - ) + # Tri-state KV cache flag: None (unspecified), True (enable), False (disable) + kv_group = parser.add_mutually_exclusive_group() + kv_group.add_argument( + "--enable_kv_cache", + dest="enable_kv_cache", + help="Enable KV-cache quantization", + action="store_true", + ) + kv_group.add_argument( + "--disable_kv_cache", + dest="enable_kv_cache", + help="Disable KV-cache quantization", + action="store_false", + ) + parser.set_defaults(enable_kv_cache=None) @@ - "--kv_cache_qformat", - args.kv_cache_qformat, - "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache", + "--kv_cache_qformat", + args.kv_cache_qformat, + *( + ["--enable_kv_cache"] if args.enable_kv_cache is True + else (["--disable_kv_cache"] if args.enable_kv_cache is False else []) + ),Also applies to: 196-201
148-156
: Guard globals used in main() for programmatic invocation. SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL are defined under `__main__` only; importing this module and calling main(args) will raise NameError. (Matches prior feedback.)

```diff
 def main(args):
+    # Fallbacks if constants aren’t initialized (e.g., when called programmatically)
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)
```

Also applies to: 224-233, 242-249, 375-379
283-285
: Slurm multi‑node training bug: nodes not set on executor. Training will always request 1 node because nodes isn’t passed to create_slurm_executor. NeMo config alone won’t fix the Slurm allocation.

```diff
-        train_gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
-        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG,
+            nodes=args.train_nodes,
+            num_gpus=args.train_gpus,
+            ntasks_per_node=args.train_gpus,
+        )
```

examples/nemo_run/qat/README.md (1)
70-71
: Token env var and permissions guidance. Use the canonical env var and avoid `chmod 777`; suggest running as host UID and/or mounting a writable logs dir.

```diff
-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+Set your Hugging Face token (HF_TOKEN is also recognized):
+`export HUGGING_FACE_HUB_TOKEN=<your-token>`  # or `export HF_TOKEN=<your-token>`
+
+For write access, avoid `chmod 777`. Run the container as your user (`-u $(id -u):$(id -g)`) and/or mount a writable logs directory (e.g., `-v /home/user/logs:/logs`) and pass `--log-dir /logs`.
```
🧹 Nitpick comments (3)
examples/nemo_run/qat/nemo_qat_flow.py (2)
278-287
: Avoid mutating the training executor for export; use a dedicated export executor. Changing ntasks_per_node on train_gpu_executor is brittle and can affect subsequent tasks. Create a dedicated export executor that requests all GPUs but runs a single task.

```diff
@@
-        cpu_executor = create_slurm_executor(SLURM_CONFIG)
+        cpu_executor = create_slurm_executor(SLURM_CONFIG)
         ptq_gpu_executor = create_slurm_executor(
             SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
         )
         train_gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
+            SLURM_CONFIG, nodes=args.train_nodes, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
         )
-        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
+        single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
+        # Export needs visibility to all GPUs but only 1 task due to NeMo bug
+        export_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=1
+        )
@@
-        cpu_executor = single_gpu_executor = run.LocalExecutor()
+        cpu_executor = single_gpu_executor = run.LocalExecutor()
         ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
         train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus)
+        export_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=1)
@@
-    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-    train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-    exp.add(
+    exp.add(
         export,
         tail_logs=True,
         name="07_export_hf",
-        executor=train_gpu_executor,
+        executor=export_gpu_executor,
         dependencies=[s5],
     )
```

Also applies to: 288-291, 339-346
149-156
: Clarify naming to avoid confusion between recipe and HF model name. model_name is first a recipe (e.g., qwen3_8b) then potentially overwritten with the HF basename; this risks mistakes. Use distinct variables (e.g., recipe_name and hf_model_basename).
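For illustration, a minimal sketch of the suggested split (the values are hypothetical stand-ins for the parsed CLI args, not the script's actual code):

```python
import os

finetune_recipe = "qwen3_8b"      # NeMo recipe module name from --finetune_recipe
model_name_arg = "Qwen/Qwen3-8B"  # Hugging Face model id from --model-name

recipe_name = finetune_recipe                         # used only to look up the recipe module
hf_model_basename = os.path.basename(model_name_arg)  # "Qwen3-8B", used only for paths/labels

print(recipe_name, hf_model_basename)
```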
examples/nemo_run/qat/README.md (1)
59-63
: Add editable install step after cloning to ensure mounted sources are importable. Without installs, Python may import preinstalled packages from the container.

````diff
 To run the example locally, launch a NeMo container ... then mount it onto your docker container.
 - `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
 - `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`
+
+Inside the container, install the mounted repos in editable mode (or set PYTHONPATH):
+
+```bash
+python -m pip install -U pip
+python -m pip install -e /opt/NeMo
+python -m pip install -e /workspace/TensorRT-Model-Optimizer/modelopt
+# Alternatively:
+# export PYTHONPATH=/opt/NeMo:/workspace/TensorRT-Model-Optimizer/modelopt:$PYTHONPATH
+```
````
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/nemo_run/qat/README.md (2 hunks)
- examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/nemo_run/qat/nemo_qat_flow.py (3)
- modelopt/torch/export/plugins/nemo_run.py (1)
  - export_most_recent_ckpt (24-33)
- examples/nemo_run/common/utils.py (4)
  - SlurmConfig (24-72)
  - create_slurm_executor (75-117)
  - get_finetune_recipe (120-123)
  - read_chat_template (126-128)
- examples/nemo_run/common/in_memory_mmlu.py (1)
  - parse_args (24-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
examples/nemo_run/qat/nemo_qat_flow.py (1)
272-276
: Verify finetuned_ckpt_dir value for MMLU SFT eval. Passing exp_dir may include multiple runs; if in_memory_mmlu.py expects a single experiment directory, prefer train.log.log_dir (or the most recent subdir) for determinism.

```diff
-        args=["--finetuned_ckpt_dir", exp_dir, "--tensor_parallelism", f"{args.ptq_gpus}"],
+        args=["--finetuned_ckpt_dir", train.log.log_dir, "--tensor_parallelism", f"{args.ptq_gpus}"],
```

If multiple runs are present under log_dir, consider pointing to the most recent subdir explicitly.
Force-pushed from 16f61f1 to 246a765
Actionable comments posted: 0
♻️ Duplicate comments (6)
examples/nemo_run/qat/README.md (6)
59-63
: Mounted sources won’t be importable without installs; add editable installs. Add steps so Python resolves mounted NeMo/TMO sources inside the container.

````diff
 To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container.
 - `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
 - `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`
+
+Inside the container, install the mounted repos in editable mode (or set PYTHONPATH):
+
+```bash
+python -m pip install -U pip
+python -m pip install -e /opt/NeMo
+python -m pip install -e /workspace/TensorRT-Model-Optimizer/modelopt
+# Alternatively:
+# export PYTHONPATH=/opt/NeMo:/workspace/TensorRT-Model-Optimizer/modelopt:$PYTHONPATH
+```
````
66-68
: Avoid mounting into site-packages; run as non-root and use editable installs. Mounting over site-packages is brittle; prefer workspace mounts + pip -e.

````diff
-```bash
-docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
-```
+```bash
+docker run --rm -it --gpus=all --shm-size=20g \
+  -u $(id -u):$(id -g) \
+  -v /home/user/NeMo:/opt/NeMo \
+  -v /home/user/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \
+  -v /home/user/logs:/logs \
+  nvcr.io/nvidia/nemo:25.07 bash
+
+# Inside the container:
+python -m pip install -U pip
+python -m pip install -e /opt/NeMo
+python -m pip install -e /workspace/TensorRT-Model-Optimizer/modelopt
+```
````
70-71
: Use canonical HF token env var; don’t recommend chmod 777. Prefer HUGGING_FACE_HUB_TOKEN (HF_TOKEN as alias). Avoid world-writable perms; run as host user and/or mount a writable logs dir.

```diff
-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+Set your Hugging Face token (HF_TOKEN is also recognized):
+`export HUGGING_FACE_HUB_TOKEN=<your-token>`  # or `export HF_TOKEN=<your-token>`
+
+For write access, avoid `chmod 777`. Instead run the container as your user (`-u $(id -u):$(id -g)`) and/or mount a writable logs directory (e.g., `-v /home/user/logs:/logs`) and pass `--log-dir /logs`.
```
86-86
: CLI flag mismatch: use --enable_kv_cache (underscore). Docs should match the script; also use a pipe in the choice list and mention the disable flag.

```diff
-> **_NOTE:_** To enable KV cache quantization, add `--enable-kv-cache` and specify qformat using `--kv-cache-qformat <fp8, nvfp4>`.
+> **_NOTE:_** To enable KV cache quantization, add `--enable_kv_cache` and specify qformat using `--kv-cache-qformat <fp8|nvfp4>`. To explicitly disable it, use `--disable_kv_cache`.
```
53-53
: Move “Usage” nearer the top (after Overview/Flow). Improves discoverability; this aligns with earlier feedback.
98-101
: Keep Slurm details in ADVANCED.md; link instead of inline flags. Reduce duplication by referencing the Slurm doc here.

```diff
-Locally this script currently supports models that can be trained on 1 node with 8 x 80GB GPUs. On Slurm you can configure the number of nodes/gpus for training and PTQ with the following flags: `--train-nodes`, `--train-gpus`, `--ptq-gpus`.
+Locally this script supports 1 node with 8 × 80GB GPUs. For Slurm configuration (nodes/GPUs for training and PTQ), see [ADVANCED.md](ADVANCED.md).
```
🧹 Nitpick comments (3)
examples/nemo_run/qat/README.md (3)
5-7
: Dedup header links; both point to ADVANCED.md. Avoid redundant links; keep one combined link.

```diff
-[Slurm Examples](ADVANCED.md) |
-[Advanced Topics](ADVANCED.md) |
+[Slurm & Advanced Topics](ADVANCED.md) |
 [NeMo Integration](https://github.com/NVIDIA-NeMo/NeMo/tree/main/nemo/collections/llm/modelopt)
```
51-51
: Grammar/consistency nit. Use “is also” and consistent spacing.

```diff
-The resulting exported checkpoint also is much smaller in memory at 6.4GB compared to the original BF16 checkpoint which is 16.4 GB.
+The resulting exported checkpoint is also much smaller in memory at 6.4 GB, compared to the original BF16 checkpoint at 16.4 GB.
```
57-57
: Extra space. Minor, whitespace-only spacing fix.

```diff
-You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
+You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm_qat/README.md (1 hunks)
- examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
- examples/nemo_run/common/process_openscience.py (1 hunks)
- examples/nemo_run/common/utils.py (1 hunks)
- examples/nemo_run/qat/ADVANCED.md (1 hunks)
- examples/nemo_run/qat/README.md (2 hunks)
- examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
- modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
- examples/nemo_run/common/process_openscience.py
- examples/nemo_run/common/utils.py
- examples/nemo_run/common/in_memory_mmlu.py
- modelopt/torch/export/plugins/nemo_run.py
- examples/nemo_run/qat/nemo_qat_flow.py
- examples/llm_qat/README.md
- examples/nemo_run/qat/ADVANCED.md
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
examples/nemo_run/qat/README.md (2)
59-63
: Pin/validate NeMo commit — confirmed
Commit 676ed1a exists in NVIDIA‑NeMo/NeMo and is compatible with the nvcr.io/nvidia/nemo:25.07 container; no update required.
109-110
: Verify CLI flags exist in nemo_qat_flow.py. Automated search failed (missing "nemo" path); confirm `--tensor_parallelism` and `--pipeline_parallelism` are defined as argparse CLI args in nemo_qat_flow.py, or update the README to point to the correct flags.
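A quick local check, without importing NeMo, is to scan the scripts for those flag strings (paths here are assumed relative to the repo root):

```python
from pathlib import Path

# Assumed locations based on the PR's file list.
scripts = [
    "examples/nemo_run/qat/nemo_qat_flow.py",
    "examples/nemo_run/common/in_memory_mmlu.py",
]
for script in scripts:
    text = Path(script).read_text()
    for flag in ("--tensor_parallelism", "--pipeline_parallelism"):
        status = "defined" if flag in text else "not found"
        print(f"{script}: {flag} {status}")
```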
Force-pushed from 246a765 to 1249de1
Actionable comments posted: 1
♻️ Duplicate comments (6)
examples/nemo_run/common/utils.py (2)
120-124
: Guard missing recipe module; avoid AttributeError. Use a safe getattr and raise a clear ValueError.

```diff
-def get_finetune_recipe(recipe_name: str):
-    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
-        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
-    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
+def get_finetune_recipe(recipe_name: str):
+    recipe_mod = getattr(llm, recipe_name, None)
+    if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"):
+        raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe")
+    return recipe_mod.finetune_recipe(peft_scheme=None)
```
91-105
: Propagate env vars to Slurm jobs (GPU + CPU executors).
`SlurmConfig.env_vars` is never applied; pass it through to SlurmExecutor.

```diff
 return run.SlurmExecutor(
     account=slurm_cfg.account,
     partition=slurm_cfg.partition_gpu,
     ntasks_per_node=ntasks_per_node,
     gpus_per_node=num_gpus,
     nodes=nodes,
     tunnel=tunnel,
     container_image=slurm_cfg.container_image,
     container_mounts=slurm_cfg.container_mounts,
     time=slurm_cfg.time,
     packager=run.GitArchivePackager(),
     mem="0",
     gres=f"gpu:{num_gpus}",
+    env_vars=slurm_cfg.env_vars,
 )
@@
 return run.SlurmExecutor(
     account=slurm_cfg.account,
     partition=slurm_cfg.partition_cpu,
     nodes=nodes,
     tunnel=tunnel,
     container_image=slurm_cfg.container_image,
     container_mounts=slurm_cfg.container_mounts,
     time=slurm_cfg.time,
     packager=run.GitArchivePackager(),
     mem="0",
+    env_vars=slurm_cfg.env_vars,
 )
```

Run to confirm the constructor param name (env_vars/env) for your installed nemo-run:

```bash
#!/usr/bin/env bash
python - << 'PY'
import inspect
import nemo_run as run
sig = inspect.signature(run.SlurmExecutor)
print(sig)
print([p for p in sig.parameters])
PY
```

Also applies to: 107-117
examples/nemo_run/qat/nemo_qat_flow.py (4)
185-203
: Only pass the KV‑cache flag when explicitly set. Build the flag conditionally; avoid a hard “disable”.

```diff
-    ptq = run.Script(
+    kv_cache_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
         "/opt/NeMo/scripts/llm/ptq.py",
         args=[
             "-nc", bf16_ckpt_path,
             "-out", ptq_model_out,
             "--export_format", "nemo",
             "--algorithm", args.algorithm,
             "--kv_cache_qformat", args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            *kv_cache_flag,
             "-ctp", f"{args.ptq_gpus}",
         ],
         entrypoint="python",
     )
```
138-142
: Make the KV‑cache flag tri‑state; don’t force disable by default. Expose both enable/disable flags and default to “unspecified”.

```diff
-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    kv = parser.add_mutually_exclusive_group()
+    kv.add_argument("--enable_kv_cache", dest="enable_kv_cache", action="store_true", help="Enable KV-cache quantization")
+    kv.add_argument("--disable_kv_cache", dest="enable_kv_cache", action="store_false", help="Disable KV-cache quantization")
+    parser.set_defaults(enable_kv_cache=None)
```
148-156
: Guard globals and support programmatic invocation. Avoid NameError when main() is called outside `__main__`; validate Slurm config presence.

```diff
 def main(args):
+    # Fallbacks if constants/SLURM_CONFIG aren’t defined at import time
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)
+    if args.use_slurm and "SLURM_CONFIG" not in globals():
+        raise ValueError("SLURM_CONFIG must be defined when --use-slurm is set (provide via __main__ or module scope).")
```
338-346
: Use a dedicated single‑GPU executor for export; don’t mutate the train executor. Prevents stale Slurm settings and surprises in multi‑stage runs.

```diff
-    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-    train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-    exp.add(
-        export,
-        tail_logs=True,
-        name="07_export_hf",
-        executor=train_gpu_executor,
-        dependencies=[s5],
-    )
+    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
+    exp.add(
+        export,
+        tail_logs=True,
+        name="07_export_hf",
+        executor=single_gpu_executor,
+        dependencies=[s5],
+    )
```
🧹 Nitpick comments (4)
examples/nemo_run/common/utils.py (2)
40-51
: (Optional) Validate Slurm time format early. Prevent misconfigurations by checking HH:MM:SS.

```diff
+import re
@@
         if not self.time:
             raise ValueError("SlurmConfig.time must be set to job time limit (e.g., '02:00:00')")
+        elif not re.fullmatch(r"\d{1,2}:\d{2}:\d{2}", self.time):
+            raise ValueError("SlurmConfig.time must be in HH:MM:SS format (e.g., '04:00:00')")
```
126-128
: Open the chat template with UTF‑8 to avoid locale issues.

```diff
-def read_chat_template(template_path: str):
-    with open(template_path) as f:
+def read_chat_template(template_path: str):
+    with open(template_path, encoding="utf-8") as f:
         return f.read().strip()
```

examples/nemo_run/qat/nemo_qat_flow.py (2)
151-153
: Handle a missing/invalid recipe module gracefully. Avoid AttributeError; emit a clear message.

```diff
-    model_name = args.finetune_recipe
-    model_module = getattr(llm, model_name)
+    model_name = args.finetune_recipe
+    model_module = getattr(llm, model_name, None)
+    if model_module is None:
+        raise ValueError(f"Unknown finetune recipe module: {model_name}")
```
264-276
: Tie MMLU TP/PP to user flags (optional). Use args.tensor_parallelism/pipeline_parallelism for eval to match train; reduces confusion when the PTQ GPU count differs.

```diff
-    eval_ptq = run.Script(
+    eval_ptq = run.Script(
         mmlu_script_path,
-        args=["--nemo_ckpt", ptq_model_out, "--tensor_parallelism", f"{args.ptq_gpus}"],
+        args=[
+            "--nemo_ckpt", ptq_model_out,
+            "--tensor_parallelism", f"{args.tensor_parallelism}",
+            "--pipeline_parallelism", f"{args.pipeline_parallelism}",
+        ],
         entrypoint="python",
     )
-    eval_bf16 = run.Script(
+    eval_bf16 = run.Script(
         mmlu_script_path,
-        args=["--nemo_ckpt", bf16_ckpt_path, "--tensor_parallelism", f"{args.ptq_gpus}"],
+        args=[
+            "--nemo_ckpt", bf16_ckpt_path,
+            "--tensor_parallelism", f"{args.tensor_parallelism}",
+            "--pipeline_parallelism", f"{args.pipeline_parallelism}",
+        ],
         entrypoint="python",
     )
-    eval_sft = run.Script(
+    eval_sft = run.Script(
         mmlu_script_path,
-        args=["--finetuned_ckpt_dir", exp_dir, "--tensor_parallelism", f"{args.ptq_gpus}"],
+        args=[
+            "--finetuned_ckpt_dir", exp_dir,
+            "--tensor_parallelism", f"{args.tensor_parallelism}",
+            "--pipeline_parallelism", f"{args.pipeline_parallelism}",
+        ],
         entrypoint="python",
     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm_qat/README.md (1 hunks)
- examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
- examples/nemo_run/common/process_openscience.py (1 hunks)
- examples/nemo_run/common/utils.py (1 hunks)
- examples/nemo_run/qat/ADVANCED.md (1 hunks)
- examples/nemo_run/qat/README.md (2 hunks)
- examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
- modelopt/torch/export/plugins/nemo_run.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (6)
- examples/nemo_run/common/in_memory_mmlu.py
- examples/nemo_run/common/process_openscience.py
- examples/nemo_run/qat/ADVANCED.md
- examples/llm_qat/README.md
- modelopt/torch/export/plugins/nemo_run.py
- examples/nemo_run/qat/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/common/utils.py (1)
- examples/nemo_run/common/process_lima.py (1)
  - download_hf_dataset (29-37)
examples/nemo_run/qat/nemo_qat_flow.py (3)
- modelopt/torch/export/plugins/nemo_run.py (1)
  - export_most_recent_ckpt (24-33)
- examples/nemo_run/common/utils.py (4)
  - SlurmConfig (24-72)
  - create_slurm_executor (75-117)
  - get_finetune_recipe (120-123)
  - read_chat_template (126-128)
- examples/nemo_run/common/in_memory_mmlu.py (1)
  - parse_args (24-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-docs
- GitHub Check: code-quality
```python
        self.env_vars |= {
            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
            "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
            "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
        }
```
Don’t override caller-provided env vars; reverse the merge.
`self.env_vars |= {...}` overwrites user settings. Defaults should be the baseline; the user wins on conflicts.
```diff
-        self.env_vars |= {
-            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
-            "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
-            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
-            "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
-        }
+        defaults = {
+            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
+            "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
+            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
+            "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
+        }
+        # User-specified values take precedence
+        self.env_vars = defaults | self.env_vars
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-self.env_vars |= {
-    "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
-    "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
-    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
-    "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
-}
+defaults = {
+    "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
+    "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
+    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
+    "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
+}
+# User-specified values take precedence
+self.env_vars = defaults | self.env_vars
```
🤖 Prompt for AI Agents
In examples/nemo_run/common/utils.py around lines 67 to 72, the current use of
self.env_vars |= {...} forces these defaults to override any caller-provided
environment variables; change the merge order so caller values win. Create a
defaults dict with those keys, then merge so existing self.env_vars take
precedence (e.g., apply defaults first and then overlay self.env_vars, or set
each default only if the key is not already present). Update the code to perform
a non-destructive merge rather than overwriting user-supplied entries.
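For reference, a standalone illustration of the merge-order point (plain Python, independent of the repo):

```python
defaults = {"TRANSFORMERS_OFFLINE": "1", "NCCL_NVLS_ENABLE": "0"}
user_env = {"TRANSFORMERS_OFFLINE": "0"}  # caller explicitly wants online mode

defaults_win = user_env | defaults  # current behavior: TRANSFORMERS_OFFLINE becomes "1"
user_wins = defaults | user_env     # suggested behavior: TRANSFORMERS_OFFLINE stays "0"

print(defaults_win["TRANSFORMERS_OFFLINE"], user_wins["TRANSFORMERS_OFFLINE"])  # prints: 1 0
```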
Signed-off-by: Jennifer Chen <[email protected]>
Force-pushed from 1249de1 to 410c7dd
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/nemo_run/qat/nemo_qat_flow.py (1)
186-201
: Fix PTQ CLI flags (NeMo 25.07)
- File: examples/nemo_run/qat/nemo_qat_flow.py (lines 186–201) — the flags used are incorrect for NeMo 25.07 and will break runs. Replace the KV-cache flags and the short `-ctp` alias with the canonical names: e.g. enable FP8 KV-cache with `--use_fp8_kv_cache` (or disable paged KV cache with `--no_paged_kv_cache` as appropriate), and replace `-ctp` with the appropriate full TP flag (`--calibration_tp`, `--calibration_pp`, or `--inference_tp`) depending on intent. Update accordingly before running.
♻️ Duplicate comments (11)
examples/nemo_run/qat/README.md (4)
53-57
: Move “Usage” directly after “Overview”. Improves discoverability; aligns with prior review.
66-69
: Do not mount into site-packages or run as root; use editable installs and user UID/GID. Replace the docker run example with a safer pattern; also add editable installs for NeMo and modelopt inside the container.

````diff
-```bash
-docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
-```
+```bash
+docker run --rm -it --gpus=all --shm-size=20g \
+  -u $(id -u):$(id -g) \
+  -v /home/user/NeMo:/opt/NeMo \
+  -v /home/user/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \
+  -v /home/user/logs:/logs \
+  nvcr.io/nvidia/nemo:25.07 bash
+
+# Inside the container:
+python -m pip install -U pip
+python -m pip install -e /opt/NeMo
+python -m pip install -e /workspace/TensorRT-Model-Optimizer/modelopt
+```
````
70-70
: Use HUGGING_FACE_HUB_TOKEN and avoid chmod 777; suggest safer write-access guidance. Update the token env var and replace the 777 advice with running as host UID and/or mounting a writable logs dir.

```diff
-You will also need to set your Huggingface token with `export HF_TOKEN=<your-token>`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+Set your Hugging Face token (HF_TOKEN is also recognized):
+`export HUGGING_FACE_HUB_TOKEN=<your-token>`  # or `export HF_TOKEN=<your-token>`
+
+For write access, avoid `chmod 777`. Run the container as your user (`-u $(id -u):$(id -g)`) and/or mount a writable logs directory (e.g., `-v /home/user/logs:/logs`) and pass `--log-dir /logs`.
```
86-86
: Flag name mismatch: use --enable_kv_cache (underscore); also document --disable_kv_cache. The script defines the underscore style; the current docs use hyphens.

```diff
-> **_NOTE:_** To enable KV cache quantization, add `--enable-kv-cache` and specify qformat using `--kv-cache-qformat <fp8, nvfp4>`.
+> **_NOTE:_** To enable KV cache quantization, add `--enable_kv_cache` and specify qformat using `--kv-cache-qformat <fp8|nvfp4>`. To explicitly disable it, use `--disable_kv_cache`.
```

examples/nemo_run/common/utils.py (3)
120-124
: Guard missing recipe module to avoid AttributeError. Use a safe getattr and a clearer error.

```diff
-def get_finetune_recipe(recipe_name: str):
-    if not hasattr(getattr(llm, recipe_name), "finetune_recipe"):
-        raise ValueError(f"Recipe {recipe_name} does not have a Fine-Tuning recipe")
-    return getattr(llm, recipe_name).finetune_recipe(peft_scheme=None)
+def get_finetune_recipe(recipe_name: str):
+    recipe_mod = getattr(llm, recipe_name, None)
+    if recipe_mod is None or not hasattr(recipe_mod, "finetune_recipe"):
+        raise ValueError(f"Recipe {recipe_name} does not exist or lacks a fine-tuning recipe")
+    return recipe_mod.finetune_recipe(peft_scheme=None)
```
67-72
: Don’t overwrite user env vars; merge defaults so the user wins. The current `self.env_vars |= {...}` overrides caller-provided values. Reverse the merge.

```diff
-        self.env_vars |= {
-            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
-            "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
-            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
-            "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
-        }
+        defaults = {
+            "CUDA_DEVICE_MAX_CONNECTIONS": "1",  # Disable GPU communication/computation overlap for performance
+            "TRANSFORMERS_OFFLINE": "1",  # Disable online downloads from HuggingFace
+            "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
+            "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
+        }
+        # User-specified values take precedence
+        self.env_vars = defaults | self.env_vars
```
91-106
: Pass env vars into SlurmExecutor
`SlurmConfig.env_vars` is never applied; forward it to the executor.

```diff
 return run.SlurmExecutor(
     account=slurm_cfg.account,
     partition=slurm_cfg.partition_gpu,
     ntasks_per_node=ntasks_per_node,
     gpus_per_node=num_gpus,
     nodes=nodes,
     tunnel=tunnel,
     container_image=slurm_cfg.container_image,
     container_mounts=slurm_cfg.container_mounts,
     time=slurm_cfg.time,
     packager=run.GitArchivePackager(),
     mem="0",
     gres=f"gpu:{num_gpus}",
+    env_vars=slurm_cfg.env_vars,
 )
@@
 return run.SlurmExecutor(
     account=slurm_cfg.account,
     partition=slurm_cfg.partition_cpu,
     nodes=nodes,
     tunnel=tunnel,
     container_image=slurm_cfg.container_image,
     container_mounts=slurm_cfg.container_mounts,
     time=slurm_cfg.time,
     packager=run.GitArchivePackager(),
     mem="0",
+    env_vars=slurm_cfg.env_vars,
 )
```

To confirm the parameter name on your installed nemo-run:

```bash
#!/usr/bin/env bash
python - <<'PY'
import inspect, nemo_run as run
sig = inspect.signature(run.SlurmExecutor)
print("SlurmExecutor params:", list(sig.parameters))
PY
```

Also applies to: 107-117
examples/nemo_run/qat/nemo_qat_flow.py (4)
138-145
: KV‑cache tri‑state: only add the enable/disable flag when explicitly set. Current logic always appends “--disable_kv_cache” when not enabling, which removes neutrality. Use a mutually exclusive group with default None and build the flag conditionally.

```diff
-    parser.add_argument(
-        "--enable_kv_cache",
-        help="Enables KV-cache quantization",
-        action="store_true",
-        default=False,
-    )
+    kv_group = parser.add_mutually_exclusive_group()
+    kv_group.add_argument("--enable_kv_cache", dest="enable_kv_cache", action="store_true", help="Enable KV-cache quantization")
+    kv_group.add_argument("--disable_kv_cache", dest="enable_kv_cache", action="store_false", help="Disable KV-cache quantization")
+    parser.set_defaults(enable_kv_cache=None)
@@
-    ptq = run.Script(
+    # Build KV-cache flag only when explicitly set
+    kv_cache_flag = (
+        ["--enable_kv_cache"] if args.enable_kv_cache is True
+        else (["--disable_kv_cache"] if args.enable_kv_cache is False else [])
+    )
+    ptq = run.Script(
         "/opt/NeMo/scripts/llm/ptq.py",
         args=[
@@
-            "--kv_cache_qformat",
-            args.kv_cache_qformat,
-            "--enable_kv_cache" if args.enable_kv_cache else "--disable_kv_cache",
+            "--kv_cache_qformat", args.kv_cache_qformat,
+            *kv_cache_flag,
```

Also applies to: 185-201
222-233
: Make constants robust when imported programmatically. Guard against NameError if main() is called without the `__main__` initializers.

```diff
 def main(args):
+    # Fallbacks if constants are unset (e.g., when imported)
+    global SEQUENCE_LENGTH, MBS, GBS, TRAIN_STEPS, VAL_INTERVAL
+    SEQUENCE_LENGTH = globals().get("SEQUENCE_LENGTH", 4096)
+    MBS = globals().get("MBS", 1)
+    GBS = globals().get("GBS", 512)
+    TRAIN_STEPS = globals().get("TRAIN_STEPS", 200)
+    VAL_INTERVAL = globals().get("VAL_INTERVAL", 50)
```
283-286
: Honor --train-nodes on Slurm. The training executor currently hardcodes nodes=1. Pass `nodes=args.train_nodes`.

```diff
-        train_gpu_executor = create_slurm_executor(
-            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
-        )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG,
+            nodes=args.train_nodes,
+            num_gpus=args.train_gpus,
+            ntasks_per_node=args.train_gpus,
+        )
```
338-346
: Use the single‑GPU executor for export; don’t mutate the train executor. Mutating `ntasks_per_node` is brittle and may be ignored by the backend.

```diff
-    # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
-    train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
-    exp.add(
-        export,
-        tail_logs=True,
-        name="07_export_hf",
-        executor=train_gpu_executor,
-        dependencies=[s5],
-    )
+    # Export with a dedicated single‑GPU executor
+    exp.add(
+        export,
+        tail_logs=True,
+        name="07_export_hf",
+        executor=single_gpu_executor,
+        dependencies=[s5],
+    )
```
🧹 Nitpick comments (8)
examples/nemo_run/qat/README.md (3)
59-63
: Pin NeMo commit with context. Consider adding a brief note on why commit 676ed1a is required and the date/tag, or switch to a tag for stability.
What NeMo features fixed by 676ed1a are required here (PTQ CLI flags, training hooks, etc.)?
41-51
: Results reproducibility note. Add hardware/software versions (driver/CUDA/NeMo/TMO) and the random seed for reproducibility.
51-51
: Grammar nit. “also is much smaller” → “is also much smaller.”

```diff
-The resulting exported checkpoint also is much smaller in memory at 6.4GB compared to the original BF16 checkpoint which is 16.4 GB.
+The resulting exported checkpoint is also much smaller in memory: 6.4 GB vs 16.4 GB for the original BF16 checkpoint.
```

examples/nemo_run/common/utils.py (2)
52-66
: Validate job_dir when using LocalTunnel. LocalTunnel requires job_dir as well; add a check for `use_local_tunnel=True`.

```diff
         if not self.use_local_tunnel:
             # Only validate SSH tunnel settings if not using local tunnel
@@
             )
+        else:
+            if not self.job_dir:
+                raise ValueError("SlurmConfig.job_dir must be set when use_local_tunnel is True")
```
131-139
: Optional: reuse existing helper to avoid duplication
`download_hf_dataset` duplicates logic in examples/nemo_run/common/process_lima.py. Consider importing/reusing to DRY; see the sketch below.
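A minimal sketch of the reuse (the exact import path is an assumption and depends on how the common/ scripts are imported):

```python
# In examples/nemo_run/common/utils.py, reuse the existing helper instead of redefining it.
from process_lima import download_hf_dataset  # assumes both modules sit in examples/nemo_run/common/
```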
148-156
: Fix model/recipe naming confusion; use HF basename for output dirs
`model_name = args.finetune_recipe` is misleading and the fallback never triggers. Use distinct names and drive paths from the HF model basename.

```diff
-    model_name = args.finetune_recipe
-    model_module = getattr(llm, model_name)
-    if not model_name:
-        model_name = os.path.basename(args.model_name)
-    exp_dir = f"{args.log_dir.rstrip('/')}/{args.experiment}"
+    recipe_name = args.finetune_recipe
+    try:
+        model_module = getattr(llm, recipe_name)
+    except AttributeError as e:
+        raise ValueError(f"Unknown recipe: {recipe_name}") from e
+    hf_basename = os.path.basename(args.model_name)
+    model_name = hf_basename  # used for directory naming
+    exp_dir = f"{args.log_dir.rstrip('/')}/{args.experiment}"
```
157-171
: Remove stale TODO and path ambiguity; rely on resolved paths. The comments suggest uncertainty; the code already resolves absolute vs Slurm paths. Clean up.

```diff
-    # 1. Process data
-    # TODO figure out path
-    # LOCALLY common/process.py works
-    # On slurm examples/nemo_run/common/process.py works
+    # 1. Process data
```
362-364
: Prefer HUGGING_FACE_HUB_TOKEN; keep HF_TOKEN as an alias. Align with the README and common tooling.

```diff
-        env_vars={
-            "HF_TOKEN": "",
-        },
+        env_vars={
+            "HUGGING_FACE_HUB_TOKEN": "",  # HF_TOKEN also recognized by some tools
+            "HF_TOKEN": "",
+        },
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm_qat/README.md (1 hunks)
- examples/nemo_run/common/in_memory_mmlu.py (1 hunks)
- examples/nemo_run/common/process_openscience.py (1 hunks)
- examples/nemo_run/common/utils.py (1 hunks)
- examples/nemo_run/qat/ADVANCED.md (1 hunks)
- examples/nemo_run/qat/README.md (2 hunks)
- examples/nemo_run/qat/nemo_qat_flow.py (6 hunks)
- modelopt/torch/export/plugins/nemo_run.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- examples/llm_qat/README.md
🚧 Files skipped from review as they are similar to previous changes (4)
- examples/nemo_run/common/process_openscience.py
- examples/nemo_run/common/in_memory_mmlu.py
- examples/nemo_run/qat/ADVANCED.md
- modelopt/torch/export/plugins/nemo_run.py
🧰 Additional context used
🧬 Code graph analysis (2)
examples/nemo_run/common/utils.py (1)
- examples/nemo_run/common/process_lima.py (1)
  - download_hf_dataset (29-37)
examples/nemo_run/qat/nemo_qat_flow.py (3)
- modelopt/torch/export/plugins/nemo_run.py (1)
  - export_most_recent_ckpt (24-33)
- examples/nemo_run/common/utils.py (4)
  - SlurmConfig (24-72)
  - create_slurm_executor (75-117)
  - get_finetune_recipe (120-123)
  - read_chat_template (126-128)
- examples/nemo_run/common/in_memory_mmlu.py (1)
  - parse_args (24-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
examples/nemo_run/qat/README.md (3)
19-39
: Flow stages list + diagram: aligned and clear. Eight stages match task names; diagram edges are coherent.
5-7
: Two links pointing to the same ADVANCED.md. If intentional, rename one (e.g., “Slurm Guide | Advanced Topics”) or deduplicate.
66-71
: Doc inconsistent with AI summary and prior commits. The file still shows the old docker guidance and KV-cache flag; the AI summary and past comments indicate these were addressed. Update the README accordingly.
Also applies to: 86-86
Signed-off-by: Jennifer Chen <[email protected]> Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: new feature
Overview: Support for launching the QAT/QAD Simplified Flow on Slurm, plus a Qwen3-8B QAT recipe
Usage
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Documentation